Browse Source

Tidy up and reduce copying of byte buffers

tags/v1.1.0
Chris Smith 5 years ago
parent
commit
622e2e2539

+ 3
- 0
CHANGELOG View File

@@ -1,5 +1,8 @@
1 1
 vNEXT (in development)
2 2
 
3
+ * (Internal) improved the way byte buffers are used to
4
+   reduce array copying and clean up code
5
+
3 6
 v1.0.1
4 7
 
5 8
  * Fixed issue with very long packets not fitting in buffers

+ 14
- 0
src/main/kotlin/com/dmdirc/ktirc/io/Buffers.kt View File

@@ -0,0 +1,14 @@
1
+package com.dmdirc.ktirc.io
2
+
3
+import kotlinx.io.pool.DefaultPool
4
+import java.nio.ByteBuffer
5
+
6
+private const val BUFFER_SIZE = 32768
7
+private const val MAXIMUM_POOL_SIZE = 1024
8
+
9
+internal val byteBufferPool = ByteBufferPool(MAXIMUM_POOL_SIZE, BUFFER_SIZE)
10
+
11
+internal class ByteBufferPool(maximumPoolSize: Int, private val bufferSize: Int) : DefaultPool<ByteBuffer>(maximumPoolSize) {
12
+    override fun produceInstance(): ByteBuffer = ByteBuffer.allocate(bufferSize)
13
+    override fun clearInstance(instance: ByteBuffer): ByteBuffer = instance.apply { clear() }
14
+}

+ 23
- 26
src/main/kotlin/com/dmdirc/ktirc/io/LineBufferedSocket.kt View File

@@ -8,6 +8,7 @@ import kotlinx.coroutines.channels.SendChannel
8 8
 import kotlinx.coroutines.channels.produce
9 9
 import kotlinx.coroutines.io.ByteWriteChannel
10 10
 import kotlinx.io.core.String
11
+import kotlinx.io.pool.useInstance
11 12
 import java.net.InetSocketAddress
12 13
 import java.nio.ByteBuffer
13 14
 import java.security.SecureRandom
@@ -55,7 +56,7 @@ internal class LineBufferedSocketImpl(coroutineScope: CoroutineScope, private va
55 56
 
56 57
         runBlocking {
57 58
             if (tls) {
58
-                with (SSLContext.getInstance("TLSv1.2")) {
59
+                with(SSLContext.getInstance("TLSv1.2")) {
59 60
                     init(null, tlsTrustManager?.let { arrayOf(it) }, SecureRandom.getInstanceStrong())
60 61
                     socket = TlsSocket(this@LineBufferedSocketImpl, socket, this, host)
61 62
                 }
@@ -74,35 +75,31 @@ internal class LineBufferedSocketImpl(coroutineScope: CoroutineScope, private va
74 75
 
75 76
     override val receiveChannel
76 77
         get() = produce {
77
-            defaultPool.borrow { lineBuffer ->
78
+            byteBufferPool.useInstance { lineBuffer ->
78 79
                 while (socket.isOpen) {
79
-                    defaultPool.borrow { buffer ->
80
-                        val bytesRead = socket.read(buffer)
81
-                        if (bytesRead == -1) {
82
-                            return@produce
83
-                        }
84
-                        var lastLine = 0
85
-                        for (i in 0 until bytesRead) {
86
-                            if (buffer[i] == CARRIAGE_RETURN || buffer[i] == LINE_FEED) {
87
-                                val length = i - lastLine + lineBuffer.position()
88
-
89
-                                if (length > 1) {
90
-                                    val output = ByteBuffer.allocate(length)
91
-
92
-                                    lineBuffer.flip()
93
-                                    output.put(lineBuffer)
94
-                                    lineBuffer.clear()
95
-
96
-                                    output.put(buffer.array(), lastLine, i - lastLine)
97
-                                    log.fine { "<<< ${String(output.array())}" }
98
-                                    send(output.array())
99
-                                }
100
-
101
-                                lastLine = i + 1
80
+                    val buffer = socket.read() ?: return@produce
81
+                    var lastLine = 0
82
+                    for (i in 0 until buffer.limit()) {
83
+                        if (buffer[i] == CARRIAGE_RETURN || buffer[i] == LINE_FEED) {
84
+                            val length = i - lastLine + lineBuffer.position()
85
+
86
+                            if (length > 1) {
87
+                                val output = ByteBuffer.allocate(length)
88
+
89
+                                lineBuffer.flip()
90
+                                output.put(lineBuffer)
91
+                                lineBuffer.clear()
92
+
93
+                                output.put(buffer.array(), lastLine, i - lastLine)
94
+                                log.fine { "<<< ${String(output.array())}" }
95
+                                send(output.array())
102 96
                             }
97
+
98
+                            lastLine = i + 1
103 99
                         }
104
-                        lineBuffer.put(buffer.array(), lastLine, bytesRead - lastLine)
105 100
                     }
101
+                    lineBuffer.put(buffer.array(), lastLine, buffer.limit() - lastLine)
102
+                    byteBufferPool.recycle(buffer)
106 103
                 }
107 104
             }
108 105
         }

+ 9
- 26
src/main/kotlin/com/dmdirc/ktirc/io/Sockets.kt View File

@@ -7,36 +7,17 @@ import kotlinx.coroutines.io.ByteWriteChannel
7 7
 import kotlinx.coroutines.io.close
8 8
 import kotlinx.coroutines.launch
9 9
 import kotlinx.coroutines.suspendCancellableCoroutine
10
-import kotlinx.io.pool.DefaultPool
10
+import kotlinx.io.pool.useInstance
11 11
 import java.net.SocketAddress
12 12
 import java.nio.ByteBuffer
13 13
 import java.nio.channels.*
14 14
 import kotlin.coroutines.resume
15 15
 import kotlin.coroutines.resumeWithException
16 16
 
17
-internal const val BUFFER_SIZE = 32768
18
-internal const val POOL_SIZE = 128
19
-
20
-internal val defaultPool = ByteBufferPool()
21
-
22
-internal class ByteBufferPool : DefaultPool<ByteBuffer>(POOL_SIZE) {
23
-    override fun produceInstance(): ByteBuffer = ByteBuffer.allocate(BUFFER_SIZE)
24
-    override fun clearInstance(instance: ByteBuffer): ByteBuffer = instance.apply { clear() }
25
-
26
-    inline fun <T> borrow(block: (ByteBuffer) -> T): T {
27
-        val buffer = borrow()
28
-        try {
29
-            return block(buffer)
30
-        } finally {
31
-            recycle(buffer)
32
-        }
33
-    }
34
-}
35
-
36 17
 internal interface Socket {
37 18
     fun bind(socketAddress: SocketAddress)
38 19
     suspend fun connect(socketAddress: SocketAddress)
39
-    suspend fun read(buffer: ByteBuffer): Int
20
+    suspend fun read(): ByteBuffer?
40 21
     fun close()
41 22
     val write: ByteWriteChannel
42 23
     val isOpen: Boolean
@@ -73,9 +54,9 @@ internal class PlainTextSocket(private val scope: CoroutineScope) : Socket {
73 54
         client.close()
74 55
     }
75 56
 
76
-    override suspend fun read(buffer: ByteBuffer) = try {
57
+    override suspend fun read() = try {
58
+        val buffer = byteBufferPool.borrow()
77 59
         val bytes = suspendCancellableCoroutine<Int> { continuation ->
78
-
79 60
             client.closeOnCancel(continuation)
80 61
             client.read(buffer, continuation, asyncIOHandler())
81 62
         }
@@ -83,15 +64,17 @@ internal class PlainTextSocket(private val scope: CoroutineScope) : Socket {
83 64
         if (bytes == -1) {
84 65
             close()
85 66
         }
86
-        bytes
67
+
68
+        buffer.flip()
69
+        buffer
87 70
     } catch (_: ClosedChannelException) {
88 71
         // Ignore
89
-        0
72
+        null
90 73
     }
91 74
 
92 75
     private suspend fun writeLoop() {
93 76
         while (client.isOpen) {
94
-            defaultPool.borrow { buffer ->
77
+            byteBufferPool.useInstance { buffer ->
95 78
                 writeChannel.readAvailable(buffer)
96 79
                 buffer.flip()
97 80
                 try {

+ 13
- 15
src/main/kotlin/com/dmdirc/ktirc/io/Tls.kt View File

@@ -7,6 +7,7 @@ import kotlinx.coroutines.channels.ClosedReceiveChannelException
7 7
 import kotlinx.coroutines.io.ByteChannel
8 8
 import kotlinx.coroutines.io.ByteWriteChannel
9 9
 import kotlinx.coroutines.launch
10
+import kotlinx.io.pool.useInstance
10 11
 import java.net.SocketAddress
11 12
 import java.nio.ByteBuffer
12 13
 import java.security.cert.CertificateException
@@ -97,19 +98,15 @@ internal class TlsSocket(
97 98
         }
98 99
     }
99 100
 
100
-    override suspend fun read(buffer: ByteBuffer) = try {
101
-        val nextBuffer = outgoingAppBuffers.receive()
102
-        val bytes = nextBuffer.limit()
103
-        buffer.put(nextBuffer)
104
-        defaultPool.recycle(nextBuffer)
105
-        bytes
101
+    override suspend fun read() = try {
102
+        outgoingAppBuffers.receive()
106 103
     } catch (_: ClosedReceiveChannelException) {
107
-        -1
104
+        null
108 105
     }
109 106
 
110 107
     private suspend fun wrap(): SSLEngineResult? {
111 108
         var result: SSLEngineResult? = null
112
-        defaultPool.borrow { netBuffer ->
109
+        byteBufferPool.useInstance { netBuffer ->
113 110
             if (engine.handshakeStatus <= SSLEngineResult.HandshakeStatus.FINISHED) {
114 111
                 writeChannel.readAvailable(incomingAppBuffer)
115 112
             }
@@ -125,24 +122,25 @@ internal class TlsSocket(
125 122
 
126 123
     private suspend fun unwrap(networkRead: Boolean = incomingNetBuffer.position() == 0): SSLEngineResult? {
127 124
         if (networkRead) {
128
-            val bytes = socket.read(incomingNetBuffer.slice())
129
-            if (bytes == -1) {
125
+            val buffer = socket.read()
126
+            if (buffer == null) {
130 127
                 close()
131 128
                 return null
132 129
             }
133
-            incomingNetBuffer.position(incomingNetBuffer.position() + bytes)
130
+            incomingNetBuffer.put(buffer)
131
+            byteBufferPool.recycle(buffer)
134 132
         }
135 133
 
136 134
         incomingNetBuffer.flip()
137 135
 
138
-        val buffer = defaultPool.borrow()
136
+        val buffer = byteBufferPool.borrow()
139 137
         val result = engine.unwrap(incomingNetBuffer, buffer)
140 138
         incomingNetBuffer.compact()
141 139
         if (buffer.position() > 0) {
142 140
             buffer.flip()
143 141
             outgoingAppBuffers.send(buffer)
144 142
         } else {
145
-            defaultPool.recycle(buffer)
143
+            byteBufferPool.recycle(buffer)
146 144
         }
147 145
 
148 146
         return if (result?.status == SSLEngineResult.Status.BUFFER_UNDERFLOW && !networkRead) {
@@ -158,9 +156,9 @@ internal class TlsSocket(
158 156
         socket.close()
159 157
 
160 158
         // Release any buffers we've got queued up
161
-        while(true) {
159
+        while (true) {
162 160
             outgoingAppBuffers.poll()?.let {
163
-                defaultPool.recycle(it)
161
+                byteBufferPool.recycle(it)
164 162
             } ?: break
165 163
         }
166 164
 

+ 62
- 0
src/test/kotlin/com/dmdirc/ktirc/io/ByteBufferPoolTest.kt View File

@@ -0,0 +1,62 @@
1
+package com.dmdirc.ktirc.io
2
+
3
+import kotlinx.io.pool.useInstance
4
+import org.junit.jupiter.api.Assertions.*
5
+import org.junit.jupiter.api.Test
6
+import java.nio.ByteBuffer
7
+
8
+internal class ByteBufferPoolTest {
9
+
10
+    @Test
11
+    fun `it allows borrowing of multiple unique bytebuffers`() {
12
+        val pool = ByteBufferPool(5, 10)
13
+        val buffer1 = pool.borrow()
14
+        val buffer2 = pool.borrow()
15
+        val buffer3 = pool.borrow()
16
+
17
+        assertFalse(buffer1 === buffer2)
18
+        assertFalse(buffer2 === buffer3)
19
+        assertFalse(buffer1 === buffer3)
20
+    }
21
+
22
+    @Test
23
+    fun `it produces buffers of the correct size`() {
24
+        val pool = ByteBufferPool(5, 12)
25
+        val buffer = pool.borrow()
26
+        assertEquals(12, buffer.limit())
27
+    }
28
+
29
+    @Test
30
+    fun `it reuses recycled buffers`() {
31
+        val pool = ByteBufferPool(1, 10)
32
+
33
+        val buffer1 = pool.borrow()
34
+        pool.recycle(buffer1)
35
+
36
+        val buffer2 = pool.borrow()
37
+        assertTrue(buffer1 === buffer2)
38
+    }
39
+
40
+    @Test
41
+    fun `it resets buffers when reborrowing`() {
42
+        val pool = ByteBufferPool(1, 10)
43
+        val buffer1 = pool.borrow()
44
+        buffer1.put("31137".toByteArray())
45
+        pool.recycle(buffer1)
46
+
47
+        val buffer2 = pool.borrow()
48
+        assertTrue(buffer1 === buffer2)
49
+        assertEquals(0, buffer2.position())
50
+        assertEquals(10, buffer2.limit())
51
+    }
52
+
53
+    @Test
54
+    fun `borrow with block automatically returns`() {
55
+        val pool = ByteBufferPool(1, 10)
56
+        var buffer1: ByteBuffer? = null
57
+        pool.useInstance {  buffer1 = it }
58
+        val buffer2 = pool.borrow()
59
+        assertTrue(buffer1 === buffer2)
60
+    }
61
+
62
+}

Loading…
Cancel
Save