Browse Source

Tidy up and reduce copying of byte buffers

tags/v1.1.0
Chris Smith 5 years ago
parent
commit
622e2e2539

+ 3
- 0
CHANGELOG View File

1
 vNEXT (in development)
1
 vNEXT (in development)
2
 
2
 
3
+ * (Internal) improved the way byte buffers are used to
4
+   reduce array copying and clean up code
5
+
3
 v1.0.1
6
 v1.0.1
4
 
7
 
5
  * Fixed issue with very long packets not fitting in buffers
8
  * Fixed issue with very long packets not fitting in buffers

+ 14
- 0
src/main/kotlin/com/dmdirc/ktirc/io/Buffers.kt View File

1
+package com.dmdirc.ktirc.io
2
+
3
+import kotlinx.io.pool.DefaultPool
4
+import java.nio.ByteBuffer
5
+
6
+private const val BUFFER_SIZE = 32768
7
+private const val MAXIMUM_POOL_SIZE = 1024
8
+
9
+internal val byteBufferPool = ByteBufferPool(MAXIMUM_POOL_SIZE, BUFFER_SIZE)
10
+
11
+internal class ByteBufferPool(maximumPoolSize: Int, private val bufferSize: Int) : DefaultPool<ByteBuffer>(maximumPoolSize) {
12
+    override fun produceInstance(): ByteBuffer = ByteBuffer.allocate(bufferSize)
13
+    override fun clearInstance(instance: ByteBuffer): ByteBuffer = instance.apply { clear() }
14
+}

+ 23
- 26
src/main/kotlin/com/dmdirc/ktirc/io/LineBufferedSocket.kt View File

8
 import kotlinx.coroutines.channels.produce
8
 import kotlinx.coroutines.channels.produce
9
 import kotlinx.coroutines.io.ByteWriteChannel
9
 import kotlinx.coroutines.io.ByteWriteChannel
10
 import kotlinx.io.core.String
10
 import kotlinx.io.core.String
11
+import kotlinx.io.pool.useInstance
11
 import java.net.InetSocketAddress
12
 import java.net.InetSocketAddress
12
 import java.nio.ByteBuffer
13
 import java.nio.ByteBuffer
13
 import java.security.SecureRandom
14
 import java.security.SecureRandom
55
 
56
 
56
         runBlocking {
57
         runBlocking {
57
             if (tls) {
58
             if (tls) {
58
-                with (SSLContext.getInstance("TLSv1.2")) {
59
+                with(SSLContext.getInstance("TLSv1.2")) {
59
                     init(null, tlsTrustManager?.let { arrayOf(it) }, SecureRandom.getInstanceStrong())
60
                     init(null, tlsTrustManager?.let { arrayOf(it) }, SecureRandom.getInstanceStrong())
60
                     socket = TlsSocket(this@LineBufferedSocketImpl, socket, this, host)
61
                     socket = TlsSocket(this@LineBufferedSocketImpl, socket, this, host)
61
                 }
62
                 }
74
 
75
 
75
     override val receiveChannel
76
     override val receiveChannel
76
         get() = produce {
77
         get() = produce {
77
-            defaultPool.borrow { lineBuffer ->
78
+            byteBufferPool.useInstance { lineBuffer ->
78
                 while (socket.isOpen) {
79
                 while (socket.isOpen) {
79
-                    defaultPool.borrow { buffer ->
80
-                        val bytesRead = socket.read(buffer)
81
-                        if (bytesRead == -1) {
82
-                            return@produce
83
-                        }
84
-                        var lastLine = 0
85
-                        for (i in 0 until bytesRead) {
86
-                            if (buffer[i] == CARRIAGE_RETURN || buffer[i] == LINE_FEED) {
87
-                                val length = i - lastLine + lineBuffer.position()
88
-
89
-                                if (length > 1) {
90
-                                    val output = ByteBuffer.allocate(length)
91
-
92
-                                    lineBuffer.flip()
93
-                                    output.put(lineBuffer)
94
-                                    lineBuffer.clear()
95
-
96
-                                    output.put(buffer.array(), lastLine, i - lastLine)
97
-                                    log.fine { "<<< ${String(output.array())}" }
98
-                                    send(output.array())
99
-                                }
100
-
101
-                                lastLine = i + 1
80
+                    val buffer = socket.read() ?: return@produce
81
+                    var lastLine = 0
82
+                    for (i in 0 until buffer.limit()) {
83
+                        if (buffer[i] == CARRIAGE_RETURN || buffer[i] == LINE_FEED) {
84
+                            val length = i - lastLine + lineBuffer.position()
85
+
86
+                            if (length > 1) {
87
+                                val output = ByteBuffer.allocate(length)
88
+
89
+                                lineBuffer.flip()
90
+                                output.put(lineBuffer)
91
+                                lineBuffer.clear()
92
+
93
+                                output.put(buffer.array(), lastLine, i - lastLine)
94
+                                log.fine { "<<< ${String(output.array())}" }
95
+                                send(output.array())
102
                             }
96
                             }
97
+
98
+                            lastLine = i + 1
103
                         }
99
                         }
104
-                        lineBuffer.put(buffer.array(), lastLine, bytesRead - lastLine)
105
                     }
100
                     }
101
+                    lineBuffer.put(buffer.array(), lastLine, buffer.limit() - lastLine)
102
+                    byteBufferPool.recycle(buffer)
106
                 }
103
                 }
107
             }
104
             }
108
         }
105
         }

+ 9
- 26
src/main/kotlin/com/dmdirc/ktirc/io/Sockets.kt View File

7
 import kotlinx.coroutines.io.close
7
 import kotlinx.coroutines.io.close
8
 import kotlinx.coroutines.launch
8
 import kotlinx.coroutines.launch
9
 import kotlinx.coroutines.suspendCancellableCoroutine
9
 import kotlinx.coroutines.suspendCancellableCoroutine
10
-import kotlinx.io.pool.DefaultPool
10
+import kotlinx.io.pool.useInstance
11
 import java.net.SocketAddress
11
 import java.net.SocketAddress
12
 import java.nio.ByteBuffer
12
 import java.nio.ByteBuffer
13
 import java.nio.channels.*
13
 import java.nio.channels.*
14
 import kotlin.coroutines.resume
14
 import kotlin.coroutines.resume
15
 import kotlin.coroutines.resumeWithException
15
 import kotlin.coroutines.resumeWithException
16
 
16
 
17
-internal const val BUFFER_SIZE = 32768
18
-internal const val POOL_SIZE = 128
19
-
20
-internal val defaultPool = ByteBufferPool()
21
-
22
-internal class ByteBufferPool : DefaultPool<ByteBuffer>(POOL_SIZE) {
23
-    override fun produceInstance(): ByteBuffer = ByteBuffer.allocate(BUFFER_SIZE)
24
-    override fun clearInstance(instance: ByteBuffer): ByteBuffer = instance.apply { clear() }
25
-
26
-    inline fun <T> borrow(block: (ByteBuffer) -> T): T {
27
-        val buffer = borrow()
28
-        try {
29
-            return block(buffer)
30
-        } finally {
31
-            recycle(buffer)
32
-        }
33
-    }
34
-}
35
-
36
 internal interface Socket {
17
 internal interface Socket {
37
     fun bind(socketAddress: SocketAddress)
18
     fun bind(socketAddress: SocketAddress)
38
     suspend fun connect(socketAddress: SocketAddress)
19
     suspend fun connect(socketAddress: SocketAddress)
39
-    suspend fun read(buffer: ByteBuffer): Int
20
+    suspend fun read(): ByteBuffer?
40
     fun close()
21
     fun close()
41
     val write: ByteWriteChannel
22
     val write: ByteWriteChannel
42
     val isOpen: Boolean
23
     val isOpen: Boolean
73
         client.close()
54
         client.close()
74
     }
55
     }
75
 
56
 
76
-    override suspend fun read(buffer: ByteBuffer) = try {
57
+    override suspend fun read() = try {
58
+        val buffer = byteBufferPool.borrow()
77
         val bytes = suspendCancellableCoroutine<Int> { continuation ->
59
         val bytes = suspendCancellableCoroutine<Int> { continuation ->
78
-
79
             client.closeOnCancel(continuation)
60
             client.closeOnCancel(continuation)
80
             client.read(buffer, continuation, asyncIOHandler())
61
             client.read(buffer, continuation, asyncIOHandler())
81
         }
62
         }
83
         if (bytes == -1) {
64
         if (bytes == -1) {
84
             close()
65
             close()
85
         }
66
         }
86
-        bytes
67
+
68
+        buffer.flip()
69
+        buffer
87
     } catch (_: ClosedChannelException) {
70
     } catch (_: ClosedChannelException) {
88
         // Ignore
71
         // Ignore
89
-        0
72
+        null
90
     }
73
     }
91
 
74
 
92
     private suspend fun writeLoop() {
75
     private suspend fun writeLoop() {
93
         while (client.isOpen) {
76
         while (client.isOpen) {
94
-            defaultPool.borrow { buffer ->
77
+            byteBufferPool.useInstance { buffer ->
95
                 writeChannel.readAvailable(buffer)
78
                 writeChannel.readAvailable(buffer)
96
                 buffer.flip()
79
                 buffer.flip()
97
                 try {
80
                 try {

+ 13
- 15
src/main/kotlin/com/dmdirc/ktirc/io/Tls.kt View File

7
 import kotlinx.coroutines.io.ByteChannel
7
 import kotlinx.coroutines.io.ByteChannel
8
 import kotlinx.coroutines.io.ByteWriteChannel
8
 import kotlinx.coroutines.io.ByteWriteChannel
9
 import kotlinx.coroutines.launch
9
 import kotlinx.coroutines.launch
10
+import kotlinx.io.pool.useInstance
10
 import java.net.SocketAddress
11
 import java.net.SocketAddress
11
 import java.nio.ByteBuffer
12
 import java.nio.ByteBuffer
12
 import java.security.cert.CertificateException
13
 import java.security.cert.CertificateException
97
         }
98
         }
98
     }
99
     }
99
 
100
 
100
-    override suspend fun read(buffer: ByteBuffer) = try {
101
-        val nextBuffer = outgoingAppBuffers.receive()
102
-        val bytes = nextBuffer.limit()
103
-        buffer.put(nextBuffer)
104
-        defaultPool.recycle(nextBuffer)
105
-        bytes
101
+    override suspend fun read() = try {
102
+        outgoingAppBuffers.receive()
106
     } catch (_: ClosedReceiveChannelException) {
103
     } catch (_: ClosedReceiveChannelException) {
107
-        -1
104
+        null
108
     }
105
     }
109
 
106
 
110
     private suspend fun wrap(): SSLEngineResult? {
107
     private suspend fun wrap(): SSLEngineResult? {
111
         var result: SSLEngineResult? = null
108
         var result: SSLEngineResult? = null
112
-        defaultPool.borrow { netBuffer ->
109
+        byteBufferPool.useInstance { netBuffer ->
113
             if (engine.handshakeStatus <= SSLEngineResult.HandshakeStatus.FINISHED) {
110
             if (engine.handshakeStatus <= SSLEngineResult.HandshakeStatus.FINISHED) {
114
                 writeChannel.readAvailable(incomingAppBuffer)
111
                 writeChannel.readAvailable(incomingAppBuffer)
115
             }
112
             }
125
 
122
 
126
     private suspend fun unwrap(networkRead: Boolean = incomingNetBuffer.position() == 0): SSLEngineResult? {
123
     private suspend fun unwrap(networkRead: Boolean = incomingNetBuffer.position() == 0): SSLEngineResult? {
127
         if (networkRead) {
124
         if (networkRead) {
128
-            val bytes = socket.read(incomingNetBuffer.slice())
129
-            if (bytes == -1) {
125
+            val buffer = socket.read()
126
+            if (buffer == null) {
130
                 close()
127
                 close()
131
                 return null
128
                 return null
132
             }
129
             }
133
-            incomingNetBuffer.position(incomingNetBuffer.position() + bytes)
130
+            incomingNetBuffer.put(buffer)
131
+            byteBufferPool.recycle(buffer)
134
         }
132
         }
135
 
133
 
136
         incomingNetBuffer.flip()
134
         incomingNetBuffer.flip()
137
 
135
 
138
-        val buffer = defaultPool.borrow()
136
+        val buffer = byteBufferPool.borrow()
139
         val result = engine.unwrap(incomingNetBuffer, buffer)
137
         val result = engine.unwrap(incomingNetBuffer, buffer)
140
         incomingNetBuffer.compact()
138
         incomingNetBuffer.compact()
141
         if (buffer.position() > 0) {
139
         if (buffer.position() > 0) {
142
             buffer.flip()
140
             buffer.flip()
143
             outgoingAppBuffers.send(buffer)
141
             outgoingAppBuffers.send(buffer)
144
         } else {
142
         } else {
145
-            defaultPool.recycle(buffer)
143
+            byteBufferPool.recycle(buffer)
146
         }
144
         }
147
 
145
 
148
         return if (result?.status == SSLEngineResult.Status.BUFFER_UNDERFLOW && !networkRead) {
146
         return if (result?.status == SSLEngineResult.Status.BUFFER_UNDERFLOW && !networkRead) {
158
         socket.close()
156
         socket.close()
159
 
157
 
160
         // Release any buffers we've got queued up
158
         // Release any buffers we've got queued up
161
-        while(true) {
159
+        while (true) {
162
             outgoingAppBuffers.poll()?.let {
160
             outgoingAppBuffers.poll()?.let {
163
-                defaultPool.recycle(it)
161
+                byteBufferPool.recycle(it)
164
             } ?: break
162
             } ?: break
165
         }
163
         }
166
 
164
 

+ 62
- 0
src/test/kotlin/com/dmdirc/ktirc/io/ByteBufferPoolTest.kt View File

1
+package com.dmdirc.ktirc.io
2
+
3
+import kotlinx.io.pool.useInstance
4
+import org.junit.jupiter.api.Assertions.*
5
+import org.junit.jupiter.api.Test
6
+import java.nio.ByteBuffer
7
+
8
+internal class ByteBufferPoolTest {
9
+
10
+    @Test
11
+    fun `it allows borrowing of multiple unique bytebuffers`() {
12
+        val pool = ByteBufferPool(5, 10)
13
+        val buffer1 = pool.borrow()
14
+        val buffer2 = pool.borrow()
15
+        val buffer3 = pool.borrow()
16
+
17
+        assertFalse(buffer1 === buffer2)
18
+        assertFalse(buffer2 === buffer3)
19
+        assertFalse(buffer1 === buffer3)
20
+    }
21
+
22
+    @Test
23
+    fun `it produces buffers of the correct size`() {
24
+        val pool = ByteBufferPool(5, 12)
25
+        val buffer = pool.borrow()
26
+        assertEquals(12, buffer.limit())
27
+    }
28
+
29
+    @Test
30
+    fun `it reuses recycled buffers`() {
31
+        val pool = ByteBufferPool(1, 10)
32
+
33
+        val buffer1 = pool.borrow()
34
+        pool.recycle(buffer1)
35
+
36
+        val buffer2 = pool.borrow()
37
+        assertTrue(buffer1 === buffer2)
38
+    }
39
+
40
+    @Test
41
+    fun `it resets buffers when reborrowing`() {
42
+        val pool = ByteBufferPool(1, 10)
43
+        val buffer1 = pool.borrow()
44
+        buffer1.put("31137".toByteArray())
45
+        pool.recycle(buffer1)
46
+
47
+        val buffer2 = pool.borrow()
48
+        assertTrue(buffer1 === buffer2)
49
+        assertEquals(0, buffer2.position())
50
+        assertEquals(10, buffer2.limit())
51
+    }
52
+
53
+    @Test
54
+    fun `borrow with block automatically returns`() {
55
+        val pool = ByteBufferPool(1, 10)
56
+        var buffer1: ByteBuffer? = null
57
+        pool.useInstance {  buffer1 = it }
58
+        val buffer2 = pool.borrow()
59
+        assertTrue(buffer1 === buffer2)
60
+    }
61
+
62
+}

Loading…
Cancel
Save