|
@@ -1,15 +1,22 @@
|
1
|
1
|
package com.dmdirc.ktirc.io
|
2
|
2
|
|
|
3
|
+import io.ktor.network.tls.certificates.generateCertificate
|
3
|
4
|
import kotlinx.coroutines.GlobalScope
|
4
|
5
|
import kotlinx.coroutines.async
|
5
|
6
|
import kotlinx.coroutines.launch
|
6
|
7
|
import kotlinx.coroutines.runBlocking
|
7
|
|
-import org.junit.jupiter.api.Assertions.assertEquals
|
8
|
|
-import org.junit.jupiter.api.Assertions.assertNotNull
|
|
8
|
+import org.junit.jupiter.api.Assertions.*
|
9
|
9
|
import org.junit.jupiter.api.Test
|
10
|
10
|
import org.junit.jupiter.api.parallel.Execution
|
11
|
11
|
import org.junit.jupiter.api.parallel.ExecutionMode
|
|
12
|
+import sun.security.validator.ValidatorException
|
|
13
|
+import java.io.File
|
12
|
14
|
import java.net.ServerSocket
|
|
15
|
+import java.security.KeyStore
|
|
16
|
+import java.security.cert.X509Certificate
|
|
17
|
+import javax.net.ssl.KeyManagerFactory
|
|
18
|
+import javax.net.ssl.SSLContext
|
|
19
|
+import javax.net.ssl.X509TrustManager
|
13
|
20
|
|
14
|
21
|
@Execution(ExecutionMode.SAME_THREAD)
|
15
|
22
|
internal class KtorLineBufferedSocketTest {
|
|
@@ -26,6 +33,22 @@ internal class KtorLineBufferedSocketTest {
|
26
|
33
|
}
|
27
|
34
|
}
|
28
|
35
|
|
|
36
|
+ @Test
|
|
37
|
+ fun `KtorLineBufferedSocket throws trying to connect to a server with a bad TLS cert`() = runBlocking {
|
|
38
|
+ tlsServerSocket(12321).use { serverSocket ->
|
|
39
|
+ try {
|
|
40
|
+ val socket = KtorLineBufferedSocket("localhost", 12321)
|
|
41
|
+ val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
|
|
42
|
+
|
|
43
|
+ socket.connect()
|
|
44
|
+ assertNotNull(clientSocketAsync.await())
|
|
45
|
+ fail<Unit>()
|
|
46
|
+ } catch (ex : ValidatorException) {
|
|
47
|
+ // Expected
|
|
48
|
+ }
|
|
49
|
+ }
|
|
50
|
+ }
|
|
51
|
+
|
29
|
52
|
@Test
|
30
|
53
|
fun `KtorLineBufferedSocket can send a whole byte array to a server`() = runBlocking {
|
31
|
54
|
ServerSocket(12321).use { serverSocket ->
|
|
@@ -64,6 +87,26 @@ internal class KtorLineBufferedSocketTest {
|
64
|
87
|
}
|
65
|
88
|
}
|
66
|
89
|
|
|
90
|
+ @Test
|
|
91
|
+ fun `KtorLineBufferedSocket can send a string to a server over TLS`() = runBlocking {
|
|
92
|
+ tlsServerSocket(12321).use { serverSocket ->
|
|
93
|
+ val socket = KtorLineBufferedSocket("localhost", 12321, true)
|
|
94
|
+ socket.tlsTrustManager = getTrustingManager()
|
|
95
|
+ val clientBytesAsync = GlobalScope.async {
|
|
96
|
+ ByteArray(13).apply {
|
|
97
|
+ serverSocket.accept().getInputStream().read(this)
|
|
98
|
+ }
|
|
99
|
+ }
|
|
100
|
+
|
|
101
|
+ socket.connect()
|
|
102
|
+ socket.sendLine("Hello World")
|
|
103
|
+
|
|
104
|
+ val bytes = clientBytesAsync.await()
|
|
105
|
+ assertNotNull(bytes)
|
|
106
|
+ assertEquals("Hello World\r\n", String(bytes))
|
|
107
|
+ }
|
|
108
|
+ }
|
|
109
|
+
|
67
|
110
|
@Test
|
68
|
111
|
fun `KtorLineBufferedSocket can send a partial byte array to a server`() = runBlocking {
|
69
|
112
|
ServerSocket(12321).use { serverSocket ->
|
|
@@ -175,4 +218,28 @@ internal class KtorLineBufferedSocketTest {
|
175
|
218
|
}
|
176
|
219
|
}
|
177
|
220
|
|
|
221
|
+ private fun tlsServerSocket(port: Int): ServerSocket {
|
|
222
|
+ val keyFile = File.createTempFile("selfsigned", "jks")
|
|
223
|
+ generateCertificate(keyFile)
|
|
224
|
+
|
|
225
|
+ val keyStore = KeyStore.getInstance("JKS")
|
|
226
|
+ keyStore.load(keyFile.inputStream(), "changeit".toCharArray())
|
|
227
|
+
|
|
228
|
+ val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
|
|
229
|
+ keyManagerFactory.init(keyStore, "changeit".toCharArray())
|
|
230
|
+
|
|
231
|
+ val sslContext = SSLContext.getInstance("TLSv1.2")
|
|
232
|
+ sslContext.init(keyManagerFactory.keyManagers, null, null)
|
|
233
|
+ return sslContext.serverSocketFactory.createServerSocket(port)
|
|
234
|
+ }
|
|
235
|
+
|
|
236
|
+ private fun getTrustingManager() = object : X509TrustManager {
|
|
237
|
+ override fun getAcceptedIssuers(): Array<X509Certificate> = emptyArray()
|
|
238
|
+
|
|
239
|
+ override fun checkClientTrusted(certs: Array<X509Certificate>, authType: String) {}
|
|
240
|
+
|
|
241
|
+ override fun checkServerTrusted(certs: Array<X509Certificate>, authType: String) {}
|
|
242
|
+ }
|
|
243
|
+
|
|
244
|
+
|
178
|
245
|
}
|