Browse Source

Rework TLS buffers to avoid overflowing.

If we're keeping some bytes in the line buffer waiting for a
CR/LF, the SSLEngine may not have enough room to put a packet's
worth of data in.

Instead of managing a single buffer, keep a channel of them
and emit a complete buffer each read call.
tags/v1.0.1
Chris Smith 5 years ago
parent
commit
51b19e41b5

+ 3
- 0
CHANGELOG View File

@@ -1,5 +1,8 @@
1 1
 vNEXT (in development)
2 2
 
3
+ * Fixed issue with very long packets not fitting in buffers
4
+   after TLS decryption.
5
+
3 6
 v1.0.0
4 7
 
5 8
  * Replaced Ktor dependency with custom socket handling, which fixes

+ 25
- 14
src/main/kotlin/com/dmdirc/ktirc/io/LineBufferedSocket.kt View File

@@ -7,6 +7,7 @@ import kotlinx.coroutines.channels.ReceiveChannel
7 7
 import kotlinx.coroutines.channels.SendChannel
8 8
 import kotlinx.coroutines.channels.produce
9 9
 import kotlinx.coroutines.io.ByteWriteChannel
10
+import kotlinx.io.core.String
10 11
 import java.net.InetSocketAddress
11 12
 import java.nio.ByteBuffer
12 13
 import java.security.SecureRandom
@@ -73,23 +74,33 @@ internal class LineBufferedSocketImpl(coroutineScope: CoroutineScope, private va
73 74
 
74 75
     override val receiveChannel
75 76
         get() = produce {
76
-            val lineBuffer = ByteArray(16384)
77
-            var nextByteOffset = 0
78
-            while (socket.isOpen) {
79
-                var lineStart = 0
80
-                val bytesRead = socket.read(ByteBuffer.wrap(lineBuffer).apply { position(nextByteOffset) })
81
-                for (i in nextByteOffset until nextByteOffset + bytesRead) {
82
-                    if (lineBuffer[i] == CARRIAGE_RETURN || lineBuffer[i] == LINE_FEED) {
83
-                        if (lineStart < i) {
84
-                            val line = lineBuffer.sliceArray(lineStart until i)
85
-                            log.fine { "<<< ${String(line)}" }
86
-                            send(line)
77
+            defaultPool.borrow { lineBuffer ->
78
+                while (socket.isOpen) {
79
+                    defaultPool.borrow { buffer ->
80
+                        val bytesRead = socket.read(buffer)
81
+                        var lastLine = 0
82
+                        for (i in 0 until bytesRead) {
83
+                            if (buffer[i] == CARRIAGE_RETURN || buffer[i] == LINE_FEED) {
84
+                                val length = i - lastLine + lineBuffer.position()
85
+
86
+                                if (length > 1) {
87
+                                    val output = ByteBuffer.allocate(length)
88
+
89
+                                    lineBuffer.flip()
90
+                                    output.put(lineBuffer)
91
+                                    lineBuffer.clear()
92
+
93
+                                    output.put(buffer.array(), lastLine, i - lastLine)
94
+                                    log.fine { "<<< ${String(output.array())}" }
95
+                                    send(output.array())
96
+                                }
97
+
98
+                                lastLine = i + 1
99
+                            }
87 100
                         }
88
-                        lineStart = i + 1
101
+                        lineBuffer.put(buffer.array(), lastLine, bytesRead - lastLine)
89 102
                     }
90 103
                 }
91
-                lineBuffer.copyInto(lineBuffer, 0, lineStart)
92
-                nextByteOffset += bytesRead - lineStart
93 104
             }
94 105
         }
95 106
 

+ 1
- 1
src/main/kotlin/com/dmdirc/ktirc/io/Sockets.kt View File

@@ -15,7 +15,7 @@ import kotlin.coroutines.resume
15 15
 import kotlin.coroutines.resumeWithException
16 16
 
17 17
 internal const val BUFFER_SIZE = 32768
18
-internal const val POOL_SIZE = 16
18
+internal const val POOL_SIZE = 128
19 19
 
20 20
 internal val defaultPool = ByteBufferPool()
21 21
 

+ 30
- 16
src/main/kotlin/com/dmdirc/ktirc/io/Tls.kt View File

@@ -1,11 +1,11 @@
1 1
 package com.dmdirc.ktirc.io
2 2
 
3
+import com.dmdirc.ktirc.util.logger
3 4
 import kotlinx.coroutines.CoroutineScope
5
+import kotlinx.coroutines.channels.Channel
4 6
 import kotlinx.coroutines.io.ByteChannel
5 7
 import kotlinx.coroutines.io.ByteWriteChannel
6 8
 import kotlinx.coroutines.launch
7
-import kotlinx.coroutines.sync.Mutex
8
-import kotlinx.coroutines.sync.withLock
9 9
 import java.net.SocketAddress
10 10
 import java.nio.ByteBuffer
11 11
 import java.security.cert.CertificateException
@@ -25,12 +25,12 @@ internal class TlsSocket(
25 25
         private val hostname: String
26 26
 ) : Socket {
27 27
 
28
+    private val log by logger()
28 29
     private var engine: SSLEngine = sslContext.createSSLEngine()
29 30
 
30 31
     private var incomingNetBuffer = ByteBuffer.allocate(0)
31
-    private var outgoingAppBuffer = ByteBuffer.allocate(0)
32 32
     private var incomingAppBuffer = ByteBuffer.allocate(0)
33
-    private val outgoingAppBufferMutex = Mutex(false)
33
+    private var outgoingAppBuffers = Channel<ByteBuffer>(capacity = Channel.UNLIMITED)
34 34
 
35 35
     private var writeChannel = ByteChannel(autoFlush = true)
36 36
 
@@ -52,7 +52,7 @@ internal class TlsSocket(
52 52
         }
53 53
 
54 54
         incomingNetBuffer = ByteBuffer.allocate(engine.session.packetBufferSize)
55
-        outgoingAppBuffer = ByteBuffer.allocate(engine.session.applicationBufferSize)
55
+        outgoingAppBuffers = Channel(capacity = Channel.UNLIMITED)
56 56
         incomingAppBuffer = ByteBuffer.allocate(engine.session.applicationBufferSize)
57 57
 
58 58
         socket.connect(socketAddress)
@@ -96,11 +96,11 @@ internal class TlsSocket(
96 96
         }
97 97
     }
98 98
 
99
-    override suspend fun read(buffer: ByteBuffer) = outgoingAppBufferMutex.withLock<Int> {
100
-        outgoingAppBuffer.flip()
101
-        val bytes = outgoingAppBuffer.limit()
102
-        buffer.put(outgoingAppBuffer)
103
-        outgoingAppBuffer.clear()
99
+    override suspend fun read(buffer: ByteBuffer): Int {
100
+        val nextBuffer = outgoingAppBuffers.receive()
101
+        val bytes = nextBuffer.limit()
102
+        buffer.put(nextBuffer)
103
+        defaultPool.recycle(nextBuffer)
104 104
         return bytes
105 105
     }
106 106
 
@@ -120,8 +120,8 @@ internal class TlsSocket(
120 120
         return result
121 121
     }
122 122
 
123
-    private suspend fun unwrap(): SSLEngineResult? {
124
-        if (incomingNetBuffer.position() == 0) {
123
+    private suspend fun unwrap(networkRead: Boolean = incomingNetBuffer.position() == 0): SSLEngineResult? {
124
+        if (networkRead) {
125 125
             val bytes = socket.read(incomingNetBuffer.slice())
126 126
             if (bytes == -1) {
127 127
                 close()
@@ -131,15 +131,29 @@ internal class TlsSocket(
131 131
         }
132 132
 
133 133
         incomingNetBuffer.flip()
134
-        outgoingAppBufferMutex.withLock {
135
-            val result = engine.unwrap(incomingNetBuffer, outgoingAppBuffer)
136
-            incomingNetBuffer.compact()
137
-            return result
134
+
135
+        val buffer = defaultPool.borrow()
136
+        val result = engine.unwrap(incomingNetBuffer, buffer)
137
+        incomingNetBuffer.compact()
138
+        if (buffer.position() > 0) {
139
+            buffer.flip()
140
+            outgoingAppBuffers.send(buffer)
141
+        } else {
142
+            defaultPool.recycle(buffer)
143
+        }
144
+
145
+        return if (result?.status == SSLEngineResult.Status.BUFFER_UNDERFLOW && !networkRead) {
146
+            // We didn't do a network read, but SSLEngine is unhappy; force a read.
147
+            log.finest { "Incoming net buffer underflowed, forcing re-read" }
148
+            unwrap(true)
149
+        } else {
150
+            result
138 151
         }
139 152
     }
140 153
 
141 154
     override fun close() {
142 155
         socket.close()
156
+        outgoingAppBuffers.close()
143 157
     }
144 158
 
145 159
     private suspend fun readLoop() {

Loading…
Cancel
Save