Browse Source

Add some TLS tests, fix a leaky buffer

tags/v1.1.0
Chris Smith 5 years ago
parent
commit
28c9400250

+ 16
- 13
src/main/kotlin/com/dmdirc/ktirc/io/Sockets.kt View File

54
         client.close()
54
         client.close()
55
     }
55
     }
56
 
56
 
57
-    override suspend fun read() = try {
57
+    override suspend fun read(): ByteBuffer? {
58
         val buffer = byteBufferPool.borrow()
58
         val buffer = byteBufferPool.borrow()
59
-        val bytes = suspendCancellableCoroutine<Int> { continuation ->
60
-            client.closeOnCancel(continuation)
61
-            client.read(buffer, continuation, asyncIOHandler())
62
-        }
59
+        try {
60
+            val bytes = suspendCancellableCoroutine<Int> { continuation ->
61
+                client.closeOnCancel(continuation)
62
+                client.read(buffer, continuation, asyncIOHandler())
63
+            }
63
 
64
 
64
-        if (bytes == -1) {
65
-            close()
66
-        }
65
+            if (bytes == -1) {
66
+                close()
67
+            }
67
 
68
 
68
-        buffer.flip()
69
-        buffer
70
-    } catch (_: ClosedChannelException) {
71
-        // Ignore
72
-        null
69
+            buffer.flip()
70
+            return buffer
71
+        } catch (_: ClosedChannelException) {
72
+            // Ignore
73
+            byteBufferPool.recycle(buffer)
74
+            return null
75
+        }
73
     }
76
     }
74
 
77
 
75
     private suspend fun writeLoop() {
78
     private suspend fun writeLoop() {

+ 11
- 36
src/test/kotlin/com/dmdirc/ktirc/io/LineBufferedSocketImplTest.kt View File

7
 import org.junit.jupiter.api.parallel.Execution
7
 import org.junit.jupiter.api.parallel.Execution
8
 import org.junit.jupiter.api.parallel.ExecutionMode
8
 import org.junit.jupiter.api.parallel.ExecutionMode
9
 import java.net.ServerSocket
9
 import java.net.ServerSocket
10
-import java.security.KeyStore
11
-import java.security.cert.X509Certificate
12
-import javax.net.ssl.KeyManagerFactory
13
-import javax.net.ssl.SSLContext
14
-import javax.net.ssl.X509TrustManager
15
 
10
 
11
+@Suppress("BlockingMethodInNonBlockingContext")
16
 @ExperimentalCoroutinesApi
12
 @ExperimentalCoroutinesApi
17
 @Execution(ExecutionMode.SAME_THREAD)
13
 @Execution(ExecutionMode.SAME_THREAD)
18
 internal class LineBufferedSocketImplTest {
14
 internal class LineBufferedSocketImplTest {
19
 
15
 
20
     @Test
16
     @Test
21
-    fun `KtorLineBufferedSocket can connect to a server`() = runBlocking {
17
+    fun `can connect to a server`() = runBlocking {
22
         ServerSocket(12321).use { serverSocket ->
18
         ServerSocket(12321).use { serverSocket ->
23
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
19
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
24
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
20
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
30
     }
26
     }
31
 
27
 
32
     @Test
28
     @Test
33
-    fun `KtorLineBufferedSocket can send a byte array to a server`() = runBlocking {
29
+    fun `can send a byte array to a server`() = runBlocking {
34
         ServerSocket(12321).use { serverSocket ->
30
         ServerSocket(12321).use { serverSocket ->
35
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
31
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
36
             val clientBytesAsync = GlobalScope.async {
32
             val clientBytesAsync = GlobalScope.async {
49
     }
45
     }
50
 
46
 
51
     @Test
47
     @Test
52
-    fun `KtorLineBufferedSocket can send a string to a server over TLS`() = runBlocking {
48
+    fun `can send a string to a server over TLS`() = runBlocking {
53
         tlsServerSocket(12321).use { serverSocket ->
49
         tlsServerSocket(12321).use { serverSocket ->
54
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321, true)
50
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321, true)
55
             socket.tlsTrustManager = getTrustingManager()
51
             socket.tlsTrustManager = getTrustingManager()
69
     }
65
     }
70
 
66
 
71
     @Test
67
     @Test
72
-    fun `KtorLineBufferedSocket can receive a line of CRLF delimited text`() = runBlocking {
68
+    fun `can receive a line of CRLF delimited text`() = runBlocking {
73
         ServerSocket(12321).use { serverSocket ->
69
         ServerSocket(12321).use { serverSocket ->
74
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
70
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
75
             GlobalScope.launch {
71
             GlobalScope.launch {
82
     }
78
     }
83
 
79
 
84
     @Test
80
     @Test
85
-    fun `KtorLineBufferedSocket can receive a line of LF delimited text`() = runBlocking {
81
+    fun `can receive a line of LF delimited text`() = runBlocking {
86
         ServerSocket(12321).use { serverSocket ->
82
         ServerSocket(12321).use { serverSocket ->
87
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
83
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
88
             GlobalScope.launch {
84
             GlobalScope.launch {
95
     }
91
     }
96
 
92
 
97
     @Test
93
     @Test
98
-    fun `KtorLineBufferedSocket can receive multiple lines of text in one packet`() = runBlocking {
94
+    fun `can receive multiple lines of text in one packet`() = runBlocking {
99
         ServerSocket(12321).use { serverSocket ->
95
         ServerSocket(12321).use { serverSocket ->
100
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
96
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
101
             GlobalScope.launch {
97
             GlobalScope.launch {
110
     }
106
     }
111
 
107
 
112
     @Test
108
     @Test
113
-    fun `KtorLineBufferedSocket can receive multiple long lines of text`() = runBlocking {
109
+    fun `can receive multiple long lines of text`() = runBlocking {
114
         ServerSocket(12321).use { serverSocket ->
110
         ServerSocket(12321).use { serverSocket ->
115
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
111
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
116
             val line1 = "abcdefghijklmnopqrstuvwxyz".repeat(500)
112
             val line1 = "abcdefghijklmnopqrstuvwxyz".repeat(500)
129
     }
125
     }
130
 
126
 
131
     @Test
127
     @Test
132
-    fun `KtorLineBufferedSocket can receive one line of text over multiple packets`() = runBlocking {
128
+    fun `can receive one line of text over multiple packets`() = runBlocking {
133
         ServerSocket(12321).use { serverSocket ->
129
         ServerSocket(12321).use { serverSocket ->
134
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
130
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
135
             GlobalScope.launch {
131
             GlobalScope.launch {
150
     }
146
     }
151
 
147
 
152
     @Test
148
     @Test
153
-    fun `KtorLineBufferedSocket returns from readLines when socket is closed`() = runBlocking {
149
+    fun `returns from readLines when socket is closed`() = runBlocking {
154
         ServerSocket(12321).use { serverSocket ->
150
         ServerSocket(12321).use { serverSocket ->
155
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
151
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
156
             GlobalScope.launch {
152
             GlobalScope.launch {
167
     }
163
     }
168
 
164
 
169
     @Test
165
     @Test
170
-    fun `KtorLineBufferedSocket disconnects from server`() = runBlocking {
166
+    fun `disconnects from server`() = runBlocking {
171
         ServerSocket(12321).use { serverSocket ->
167
         ServerSocket(12321).use { serverSocket ->
172
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
168
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
173
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
169
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
179
         }
175
         }
180
     }
176
     }
181
 
177
 
182
-    private fun tlsServerSocket(port: Int): ServerSocket {
183
-        val keyStore = KeyStore.getInstance("PKCS12")
184
-        keyStore.load(LineBufferedSocketImplTest::class.java.getResourceAsStream("localhost.p12"), CharArray(0))
185
-
186
-        val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
187
-        keyManagerFactory.init(keyStore, CharArray(0))
188
-
189
-        val sslContext = SSLContext.getInstance("TLSv1.2")
190
-        sslContext.init(keyManagerFactory.keyManagers, null, null)
191
-        return sslContext.serverSocketFactory.createServerSocket(port)
192
-    }
193
-
194
-    private fun getTrustingManager() = object : X509TrustManager {
195
-        override fun getAcceptedIssuers(): Array<X509Certificate>  = emptyArray()
196
-
197
-        override fun checkClientTrusted(certs: Array<X509Certificate>, authType: String) {}
198
-
199
-        override fun checkServerTrusted(certs: Array<X509Certificate>, authType: String) {}
200
-    }
201
-
202
-
203
 }
178
 }

+ 86
- 2
src/test/kotlin/com/dmdirc/ktirc/io/TlsTest.kt View File

2
 
2
 
3
 import io.mockk.every
3
 import io.mockk.every
4
 import io.mockk.mockk
4
 import io.mockk.mockk
5
-import org.junit.jupiter.api.Assertions.assertFalse
6
-import org.junit.jupiter.api.Assertions.assertTrue
5
+import kotlinx.coroutines.GlobalScope
6
+import kotlinx.coroutines.async
7
+import kotlinx.coroutines.io.writeFully
8
+import kotlinx.coroutines.launch
9
+import kotlinx.coroutines.runBlocking
10
+import kotlinx.io.core.String
11
+import org.junit.jupiter.api.Assertions
12
+import org.junit.jupiter.api.Assertions.*
7
 import org.junit.jupiter.api.Test
13
 import org.junit.jupiter.api.Test
14
+import org.junit.jupiter.api.parallel.Execution
15
+import org.junit.jupiter.api.parallel.ExecutionMode
16
+import java.net.InetSocketAddress
17
+import java.net.ServerSocket
18
+import java.security.KeyStore
8
 import java.security.cert.CertificateException
19
 import java.security.cert.CertificateException
9
 import java.security.cert.X509Certificate
20
 import java.security.cert.X509Certificate
21
+import javax.net.ssl.KeyManagerFactory
22
+import javax.net.ssl.SSLContext
23
+import javax.net.ssl.X509TrustManager
10
 
24
 
11
 internal class CertificateValidationTest {
25
 internal class CertificateValidationTest {
12
 
26
 
149
         assertFalse(cert.validFor("directory.test.ktirc"))
163
         assertFalse(cert.validFor("directory.test.ktirc"))
150
     }
164
     }
151
 
165
 
166
+}
167
+
168
+@Suppress("BlockingMethodInNonBlockingContext")
169
+@Execution(ExecutionMode.SAME_THREAD)
170
+internal class TlsSocketTest {
171
+
172
+    @Test
173
+    fun `can send a string to a server over TLS`() = runBlocking {
174
+        tlsServerSocket(12321).use { serverSocket ->
175
+            val plainSocket = PlainTextSocket(GlobalScope)
176
+            val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
177
+            val clientBytesAsync = GlobalScope.async {
178
+                ByteArray(13).apply {
179
+                    serverSocket.accept().getInputStream().read(this)
180
+                }
181
+            }
182
+
183
+            tlsSocket.connect(InetSocketAddress("localhost", 12321))
184
+            tlsSocket.write.writeFully("Hello World\r\n".toByteArray())
185
+
186
+            val bytes = clientBytesAsync.await()
187
+            Assertions.assertNotNull(bytes)
188
+            Assertions.assertEquals("Hello World\r\n", String(bytes))
189
+        }
190
+    }
191
+
192
+    @Test
193
+    fun `throws if the hostname mismatches`() {
194
+        tlsServerSocket(12321).use { serverSocket ->
195
+            val plainSocket = PlainTextSocket(GlobalScope)
196
+            val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "127.0.0.1")
197
+            GlobalScope.launch {
198
+                serverSocket.accept().getInputStream().read()
199
+            }
200
+
201
+            runBlocking {
202
+                try {
203
+                    tlsSocket.connect(InetSocketAddress("localhost", 12321))
204
+                    fail<Unit>("Expected an exception")
205
+                } catch (ex: Exception) {
206
+                    assertTrue(ex is CertificateException)
207
+                }
208
+            }
209
+        }
210
+    }
211
+
212
+
213
+}
214
+
215
+internal fun tlsServerSocket(port: Int): ServerSocket {
216
+    val keyStore = KeyStore.getInstance("PKCS12")
217
+    keyStore.load(CertificateValidationTest::class.java.getResourceAsStream("localhost.p12"), CharArray(0))
218
+
219
+    val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
220
+    keyManagerFactory.init(keyStore, CharArray(0))
221
+
222
+    val sslContext = SSLContext.getInstance("TLSv1.2")
223
+    sslContext.init(keyManagerFactory.keyManagers, null, null)
224
+    return sslContext.serverSocketFactory.createServerSocket(port)
225
+}
226
+
227
+internal fun getTrustingContext() =
228
+        SSLContext.getInstance("TLSv1.2").apply { init(null, arrayOf(getTrustingManager()), null) }
229
+
230
+internal fun getTrustingManager() = object : X509TrustManager {
231
+    override fun getAcceptedIssuers(): Array<X509Certificate> = emptyArray()
232
+
233
+    override fun checkClientTrusted(certs: Array<X509Certificate>, authType: String) {}
234
+
235
+    override fun checkServerTrusted(certs: Array<X509Certificate>, authType: String) {}
152
 }
236
 }

Loading…
Cancel
Save