Browse Source

Add some TLS tests, fix a leaky buffer

tags/v1.1.0
Chris Smith 5 years ago
parent
commit
28c9400250

+ 16
- 13
src/main/kotlin/com/dmdirc/ktirc/io/Sockets.kt View File

@@ -54,22 +54,25 @@ internal class PlainTextSocket(private val scope: CoroutineScope) : Socket {
54 54
         client.close()
55 55
     }
56 56
 
57
-    override suspend fun read() = try {
57
+    override suspend fun read(): ByteBuffer? {
58 58
         val buffer = byteBufferPool.borrow()
59
-        val bytes = suspendCancellableCoroutine<Int> { continuation ->
60
-            client.closeOnCancel(continuation)
61
-            client.read(buffer, continuation, asyncIOHandler())
62
-        }
59
+        try {
60
+            val bytes = suspendCancellableCoroutine<Int> { continuation ->
61
+                client.closeOnCancel(continuation)
62
+                client.read(buffer, continuation, asyncIOHandler())
63
+            }
63 64
 
64
-        if (bytes == -1) {
65
-            close()
66
-        }
65
+            if (bytes == -1) {
66
+                close()
67
+            }
67 68
 
68
-        buffer.flip()
69
-        buffer
70
-    } catch (_: ClosedChannelException) {
71
-        // Ignore
72
-        null
69
+            buffer.flip()
70
+            return buffer
71
+        } catch (_: ClosedChannelException) {
72
+            // Ignore
73
+            byteBufferPool.recycle(buffer)
74
+            return null
75
+        }
73 76
     }
74 77
 
75 78
     private suspend fun writeLoop() {

+ 11
- 36
src/test/kotlin/com/dmdirc/ktirc/io/LineBufferedSocketImplTest.kt View File

@@ -7,18 +7,14 @@ import org.junit.jupiter.api.Test
7 7
 import org.junit.jupiter.api.parallel.Execution
8 8
 import org.junit.jupiter.api.parallel.ExecutionMode
9 9
 import java.net.ServerSocket
10
-import java.security.KeyStore
11
-import java.security.cert.X509Certificate
12
-import javax.net.ssl.KeyManagerFactory
13
-import javax.net.ssl.SSLContext
14
-import javax.net.ssl.X509TrustManager
15 10
 
11
+@Suppress("BlockingMethodInNonBlockingContext")
16 12
 @ExperimentalCoroutinesApi
17 13
 @Execution(ExecutionMode.SAME_THREAD)
18 14
 internal class LineBufferedSocketImplTest {
19 15
 
20 16
     @Test
21
-    fun `KtorLineBufferedSocket can connect to a server`() = runBlocking {
17
+    fun `can connect to a server`() = runBlocking {
22 18
         ServerSocket(12321).use { serverSocket ->
23 19
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
24 20
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
@@ -30,7 +26,7 @@ internal class LineBufferedSocketImplTest {
30 26
     }
31 27
 
32 28
     @Test
33
-    fun `KtorLineBufferedSocket can send a byte array to a server`() = runBlocking {
29
+    fun `can send a byte array to a server`() = runBlocking {
34 30
         ServerSocket(12321).use { serverSocket ->
35 31
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
36 32
             val clientBytesAsync = GlobalScope.async {
@@ -49,7 +45,7 @@ internal class LineBufferedSocketImplTest {
49 45
     }
50 46
 
51 47
     @Test
52
-    fun `KtorLineBufferedSocket can send a string to a server over TLS`() = runBlocking {
48
+    fun `can send a string to a server over TLS`() = runBlocking {
53 49
         tlsServerSocket(12321).use { serverSocket ->
54 50
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321, true)
55 51
             socket.tlsTrustManager = getTrustingManager()
@@ -69,7 +65,7 @@ internal class LineBufferedSocketImplTest {
69 65
     }
70 66
 
71 67
     @Test
72
-    fun `KtorLineBufferedSocket can receive a line of CRLF delimited text`() = runBlocking {
68
+    fun `can receive a line of CRLF delimited text`() = runBlocking {
73 69
         ServerSocket(12321).use { serverSocket ->
74 70
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
75 71
             GlobalScope.launch {
@@ -82,7 +78,7 @@ internal class LineBufferedSocketImplTest {
82 78
     }
83 79
 
84 80
     @Test
85
-    fun `KtorLineBufferedSocket can receive a line of LF delimited text`() = runBlocking {
81
+    fun `can receive a line of LF delimited text`() = runBlocking {
86 82
         ServerSocket(12321).use { serverSocket ->
87 83
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
88 84
             GlobalScope.launch {
@@ -95,7 +91,7 @@ internal class LineBufferedSocketImplTest {
95 91
     }
96 92
 
97 93
     @Test
98
-    fun `KtorLineBufferedSocket can receive multiple lines of text in one packet`() = runBlocking {
94
+    fun `can receive multiple lines of text in one packet`() = runBlocking {
99 95
         ServerSocket(12321).use { serverSocket ->
100 96
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
101 97
             GlobalScope.launch {
@@ -110,7 +106,7 @@ internal class LineBufferedSocketImplTest {
110 106
     }
111 107
 
112 108
     @Test
113
-    fun `KtorLineBufferedSocket can receive multiple long lines of text`() = runBlocking {
109
+    fun `can receive multiple long lines of text`() = runBlocking {
114 110
         ServerSocket(12321).use { serverSocket ->
115 111
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
116 112
             val line1 = "abcdefghijklmnopqrstuvwxyz".repeat(500)
@@ -129,7 +125,7 @@ internal class LineBufferedSocketImplTest {
129 125
     }
130 126
 
131 127
     @Test
132
-    fun `KtorLineBufferedSocket can receive one line of text over multiple packets`() = runBlocking {
128
+    fun `can receive one line of text over multiple packets`() = runBlocking {
133 129
         ServerSocket(12321).use { serverSocket ->
134 130
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
135 131
             GlobalScope.launch {
@@ -150,7 +146,7 @@ internal class LineBufferedSocketImplTest {
150 146
     }
151 147
 
152 148
     @Test
153
-    fun `KtorLineBufferedSocket returns from readLines when socket is closed`() = runBlocking {
149
+    fun `returns from readLines when socket is closed`() = runBlocking {
154 150
         ServerSocket(12321).use { serverSocket ->
155 151
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
156 152
             GlobalScope.launch {
@@ -167,7 +163,7 @@ internal class LineBufferedSocketImplTest {
167 163
     }
168 164
 
169 165
     @Test
170
-    fun `KtorLineBufferedSocket disconnects from server`() = runBlocking {
166
+    fun `disconnects from server`() = runBlocking {
171 167
         ServerSocket(12321).use { serverSocket ->
172 168
             val socket = LineBufferedSocketImpl(GlobalScope, "localhost", "localhost", 12321)
173 169
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
@@ -179,25 +175,4 @@ internal class LineBufferedSocketImplTest {
179 175
         }
180 176
     }
181 177
 
182
-    private fun tlsServerSocket(port: Int): ServerSocket {
183
-        val keyStore = KeyStore.getInstance("PKCS12")
184
-        keyStore.load(LineBufferedSocketImplTest::class.java.getResourceAsStream("localhost.p12"), CharArray(0))
185
-
186
-        val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
187
-        keyManagerFactory.init(keyStore, CharArray(0))
188
-
189
-        val sslContext = SSLContext.getInstance("TLSv1.2")
190
-        sslContext.init(keyManagerFactory.keyManagers, null, null)
191
-        return sslContext.serverSocketFactory.createServerSocket(port)
192
-    }
193
-
194
-    private fun getTrustingManager() = object : X509TrustManager {
195
-        override fun getAcceptedIssuers(): Array<X509Certificate>  = emptyArray()
196
-
197
-        override fun checkClientTrusted(certs: Array<X509Certificate>, authType: String) {}
198
-
199
-        override fun checkServerTrusted(certs: Array<X509Certificate>, authType: String) {}
200
-    }
201
-
202
-
203 178
 }

+ 86
- 2
src/test/kotlin/com/dmdirc/ktirc/io/TlsTest.kt View File

@@ -2,11 +2,25 @@ package com.dmdirc.ktirc.io
2 2
 
3 3
 import io.mockk.every
4 4
 import io.mockk.mockk
5
-import org.junit.jupiter.api.Assertions.assertFalse
6
-import org.junit.jupiter.api.Assertions.assertTrue
5
+import kotlinx.coroutines.GlobalScope
6
+import kotlinx.coroutines.async
7
+import kotlinx.coroutines.io.writeFully
8
+import kotlinx.coroutines.launch
9
+import kotlinx.coroutines.runBlocking
10
+import kotlinx.io.core.String
11
+import org.junit.jupiter.api.Assertions
12
+import org.junit.jupiter.api.Assertions.*
7 13
 import org.junit.jupiter.api.Test
14
+import org.junit.jupiter.api.parallel.Execution
15
+import org.junit.jupiter.api.parallel.ExecutionMode
16
+import java.net.InetSocketAddress
17
+import java.net.ServerSocket
18
+import java.security.KeyStore
8 19
 import java.security.cert.CertificateException
9 20
 import java.security.cert.X509Certificate
21
+import javax.net.ssl.KeyManagerFactory
22
+import javax.net.ssl.SSLContext
23
+import javax.net.ssl.X509TrustManager
10 24
 
11 25
 internal class CertificateValidationTest {
12 26
 
@@ -149,4 +163,74 @@ internal class CertificateValidationTest {
149 163
         assertFalse(cert.validFor("directory.test.ktirc"))
150 164
     }
151 165
 
166
+}
167
+
168
+@Suppress("BlockingMethodInNonBlockingContext")
169
+@Execution(ExecutionMode.SAME_THREAD)
170
+internal class TlsSocketTest {
171
+
172
+    @Test
173
+    fun `can send a string to a server over TLS`() = runBlocking {
174
+        tlsServerSocket(12321).use { serverSocket ->
175
+            val plainSocket = PlainTextSocket(GlobalScope)
176
+            val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
177
+            val clientBytesAsync = GlobalScope.async {
178
+                ByteArray(13).apply {
179
+                    serverSocket.accept().getInputStream().read(this)
180
+                }
181
+            }
182
+
183
+            tlsSocket.connect(InetSocketAddress("localhost", 12321))
184
+            tlsSocket.write.writeFully("Hello World\r\n".toByteArray())
185
+
186
+            val bytes = clientBytesAsync.await()
187
+            Assertions.assertNotNull(bytes)
188
+            Assertions.assertEquals("Hello World\r\n", String(bytes))
189
+        }
190
+    }
191
+
192
+    @Test
193
+    fun `throws if the hostname mismatches`() {
194
+        tlsServerSocket(12321).use { serverSocket ->
195
+            val plainSocket = PlainTextSocket(GlobalScope)
196
+            val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "127.0.0.1")
197
+            GlobalScope.launch {
198
+                serverSocket.accept().getInputStream().read()
199
+            }
200
+
201
+            runBlocking {
202
+                try {
203
+                    tlsSocket.connect(InetSocketAddress("localhost", 12321))
204
+                    fail<Unit>("Expected an exception")
205
+                } catch (ex: Exception) {
206
+                    assertTrue(ex is CertificateException)
207
+                }
208
+            }
209
+        }
210
+    }
211
+
212
+
213
+}
214
+
215
+internal fun tlsServerSocket(port: Int): ServerSocket {
216
+    val keyStore = KeyStore.getInstance("PKCS12")
217
+    keyStore.load(CertificateValidationTest::class.java.getResourceAsStream("localhost.p12"), CharArray(0))
218
+
219
+    val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
220
+    keyManagerFactory.init(keyStore, CharArray(0))
221
+
222
+    val sslContext = SSLContext.getInstance("TLSv1.2")
223
+    sslContext.init(keyManagerFactory.keyManagers, null, null)
224
+    return sslContext.serverSocketFactory.createServerSocket(port)
225
+}
226
+
227
+internal fun getTrustingContext() =
228
+        SSLContext.getInstance("TLSv1.2").apply { init(null, arrayOf(getTrustingManager()), null) }
229
+
230
+internal fun getTrustingManager() = object : X509TrustManager {
231
+    override fun getAcceptedIssuers(): Array<X509Certificate> = emptyArray()
232
+
233
+    override fun checkClientTrusted(certs: Array<X509Certificate>, authType: String) {}
234
+
235
+    override fun checkServerTrusted(certs: Array<X509Certificate>, authType: String) {}
152 236
 }

Loading…
Cancel
Save