Browse Source

TLS support

tags/v0.2.0
Chris Smith 5 years ago
parent
commit
14a192379b

+ 3
- 0
CHANGELOG View File

@@ -0,0 +1,3 @@
1
+v0.2.0 [in development]
2
+
3
+  * Added support for connecting over TLS

+ 1
- 0
build.gradle.kts View File

@@ -23,6 +23,7 @@ dependencies {
23 23
     implementation(kotlin("stdlib-jdk8", "1.3.20"))
24 24
     implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.1.1")
25 25
     implementation("io.ktor:ktor-network:1.1.2")
26
+    implementation("io.ktor:ktor-network-tls:1.1.2")
26 27
 
27 28
     testImplementation("com.nhaarman.mockitokotlin2:mockito-kotlin:2.1.0")
28 29
     testImplementation("org.junit.jupiter:junit-jupiter-api:5.3.1")

+ 2
- 2
src/main/kotlin/com/dmdirc/ktirc/IrcClient.kt View File

@@ -33,7 +33,7 @@ interface IrcClient {
33 33
 // TODO: Should there be a default profile?
34 34
 class IrcClientImpl(private val server: Server, private val profile: Profile) : IrcClient {
35 35
 
36
-    var socketFactory: (String, Int) -> LineBufferedSocket = ::KtorLineBufferedSocket
36
+    var socketFactory: (String, Int, Boolean) -> LineBufferedSocket = ::KtorLineBufferedSocket
37 37
 
38 38
     override val serverState = ServerState(profile.initialNick)
39 39
     override val channelState = ChannelStateMap { caseMapping }
@@ -59,7 +59,7 @@ class IrcClientImpl(private val server: Server, private val profile: Profile) :
59 59
         // TODO: Concurrency!
60 60
         check(socket == null)
61 61
         coroutineScope {
62
-            with(socketFactory(server.host, server.port)) {
62
+            with(socketFactory(server.host, server.port, server.tls)) {
63 63
                 socket = this
64 64
                 connect()
65 65
                 sendLine("CAP LS 302")

+ 12
- 2
src/main/kotlin/com/dmdirc/ktirc/io/LineBufferedSocket.kt View File

@@ -6,13 +6,17 @@ import io.ktor.network.sockets.Socket
6 6
 import io.ktor.network.sockets.aSocket
7 7
 import io.ktor.network.sockets.openReadChannel
8 8
 import io.ktor.network.sockets.openWriteChannel
9
+import io.ktor.network.tls.tls
9 10
 import kotlinx.coroutines.CoroutineScope
10 11
 import kotlinx.coroutines.Dispatchers
12
+import kotlinx.coroutines.GlobalScope
11 13
 import kotlinx.coroutines.channels.ReceiveChannel
12 14
 import kotlinx.coroutines.channels.produce
13 15
 import kotlinx.coroutines.io.ByteReadChannel
14 16
 import kotlinx.coroutines.io.ByteWriteChannel
15 17
 import java.net.InetSocketAddress
18
+import java.security.SecureRandom
19
+import javax.net.ssl.X509TrustManager
16 20
 
17 21
 interface LineBufferedSocket {
18 22
 
@@ -29,14 +33,16 @@ interface LineBufferedSocket {
29 33
 /**
30 34
  * Asynchronous socket that buffers incoming data and emits individual lines.
31 35
  */
32
-// TODO: TLS options
33
-class KtorLineBufferedSocket(private val host: String, private val port: Int): LineBufferedSocket {
36
+// TODO: Expose advanced TLS options
37
+class KtorLineBufferedSocket(private val host: String, private val port: Int, private val tls: Boolean = false): LineBufferedSocket {
34 38
 
35 39
     companion object {
36 40
         const val CARRIAGE_RETURN = '\r'.toByte()
37 41
         const val LINE_FEED = '\n'.toByte()
38 42
     }
39 43
 
44
+    public var tlsTrustManager: X509TrustManager? = null
45
+
40 46
     private val log by logger()
41 47
 
42 48
     private lateinit var socket: Socket
@@ -47,6 +53,10 @@ class KtorLineBufferedSocket(private val host: String, private val port: Int): L
47 53
     override suspend fun connect() {
48 54
         log.info { "Connecting..." }
49 55
         socket = aSocket(ActorSelectorManager(Dispatchers.IO)).tcp().connect(InetSocketAddress(host, port))
56
+        if (tls) {
57
+            // TODO: Figure out how exactly scopes work...
58
+            socket = socket.tls(GlobalScope.coroutineContext, randomAlgorithm = SecureRandom.getInstanceStrong().algorithm, trustManager = tlsTrustManager)
59
+        }
50 60
         readChannel = socket.openReadChannel()
51 61
         writeChannel = socket.openWriteChannel()
52 62
     }

+ 1
- 1
src/main/kotlin/com/dmdirc/ktirc/model/Server.kt View File

@@ -1,3 +1,3 @@
1 1
 package com.dmdirc.ktirc.model
2 2
 
3
-data class Server(val host: String, val port: Int, val ssl: Boolean = false, val password: String? = null)
3
+data class Server(val host: String, val port: Int, val tls: Boolean = false, val password: String? = null)

+ 16
- 3
src/test/kotlin/com/dmdirc/ktirc/IrcClientTest.kt View File

@@ -31,8 +31,8 @@ internal class IrcClientImplTest {
31 31
         on { readLines(any()) } doReturn readLineChannel
32 32
     }
33 33
 
34
-    private val mockSocketFactory = mock<(String, Int) -> LineBufferedSocket> {
35
-        on { invoke(HOST, PORT) } doReturn mockSocket
34
+    private val mockSocketFactory = mock<(String, Int, Boolean) -> LineBufferedSocket> {
35
+        on { invoke(eq(HOST), eq(PORT), any()) } doReturn mockSocket
36 36
     }
37 37
 
38 38
     private val mockEventHandler = mock<EventHandler>()
@@ -51,7 +51,20 @@ internal class IrcClientImplTest {
51 51
 
52 52
             client.connect()
53 53
 
54
-            verify(mockSocketFactory).invoke(HOST, PORT)
54
+            verify(mockSocketFactory).invoke(HOST, PORT, false)
55
+        }
56
+    }
57
+
58
+    @Test
59
+    fun `IrcClientImpl uses socket factory to create a new tls on connect`() {
60
+        runBlocking {
61
+            val client = IrcClientImpl(Server(HOST, PORT, true), Profile(NICK, REAL_NAME, USER_NAME))
62
+            client.socketFactory = mockSocketFactory
63
+            readLineChannel.close()
64
+
65
+            client.connect()
66
+
67
+            verify(mockSocketFactory).invoke(HOST, PORT, true)
55 68
         }
56 69
     }
57 70
 

+ 69
- 2
src/test/kotlin/com/dmdirc/ktirc/io/KtorLineBufferedSocketTest.kt View File

@@ -1,15 +1,22 @@
1 1
 package com.dmdirc.ktirc.io
2 2
 
3
+import io.ktor.network.tls.certificates.generateCertificate
3 4
 import kotlinx.coroutines.GlobalScope
4 5
 import kotlinx.coroutines.async
5 6
 import kotlinx.coroutines.launch
6 7
 import kotlinx.coroutines.runBlocking
7
-import org.junit.jupiter.api.Assertions.assertEquals
8
-import org.junit.jupiter.api.Assertions.assertNotNull
8
+import org.junit.jupiter.api.Assertions.*
9 9
 import org.junit.jupiter.api.Test
10 10
 import org.junit.jupiter.api.parallel.Execution
11 11
 import org.junit.jupiter.api.parallel.ExecutionMode
12
+import sun.security.validator.ValidatorException
13
+import java.io.File
12 14
 import java.net.ServerSocket
15
+import java.security.KeyStore
16
+import java.security.cert.X509Certificate
17
+import javax.net.ssl.KeyManagerFactory
18
+import javax.net.ssl.SSLContext
19
+import javax.net.ssl.X509TrustManager
13 20
 
14 21
 @Execution(ExecutionMode.SAME_THREAD)
15 22
 internal class KtorLineBufferedSocketTest {
@@ -26,6 +33,22 @@ internal class KtorLineBufferedSocketTest {
26 33
         }
27 34
     }
28 35
 
36
+    @Test
37
+    fun `KtorLineBufferedSocket throws trying to connect to a server with a bad TLS cert`() = runBlocking {
38
+        tlsServerSocket(12321).use { serverSocket ->
39
+            try {
40
+                val socket = KtorLineBufferedSocket("localhost", 12321)
41
+                val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
42
+
43
+                socket.connect()
44
+                assertNotNull(clientSocketAsync.await())
45
+                fail<Unit>()
46
+            } catch (ex : ValidatorException) {
47
+                // Expected
48
+            }
49
+        }
50
+    }
51
+
29 52
     @Test
30 53
     fun `KtorLineBufferedSocket can send a whole byte array to a server`() = runBlocking {
31 54
         ServerSocket(12321).use { serverSocket ->
@@ -64,6 +87,26 @@ internal class KtorLineBufferedSocketTest {
64 87
         }
65 88
     }
66 89
 
90
+    @Test
91
+    fun `KtorLineBufferedSocket can send a string to a server over TLS`() = runBlocking {
92
+        tlsServerSocket(12321).use { serverSocket ->
93
+            val socket = KtorLineBufferedSocket("localhost", 12321, true)
94
+            socket.tlsTrustManager = getTrustingManager()
95
+            val clientBytesAsync = GlobalScope.async {
96
+                ByteArray(13).apply {
97
+                    serverSocket.accept().getInputStream().read(this)
98
+                }
99
+            }
100
+
101
+            socket.connect()
102
+            socket.sendLine("Hello World")
103
+
104
+            val bytes = clientBytesAsync.await()
105
+            assertNotNull(bytes)
106
+            assertEquals("Hello World\r\n", String(bytes))
107
+        }
108
+    }
109
+
67 110
     @Test
68 111
     fun `KtorLineBufferedSocket can send a partial byte array to a server`() = runBlocking {
69 112
         ServerSocket(12321).use { serverSocket ->
@@ -175,4 +218,28 @@ internal class KtorLineBufferedSocketTest {
175 218
         }
176 219
     }
177 220
 
221
+    private fun tlsServerSocket(port: Int): ServerSocket {
222
+        val keyFile = File.createTempFile("selfsigned", "jks")
223
+        generateCertificate(keyFile)
224
+
225
+        val keyStore = KeyStore.getInstance("JKS")
226
+        keyStore.load(keyFile.inputStream(), "changeit".toCharArray())
227
+
228
+        val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
229
+        keyManagerFactory.init(keyStore, "changeit".toCharArray())
230
+
231
+        val sslContext = SSLContext.getInstance("TLSv1.2")
232
+        sslContext.init(keyManagerFactory.keyManagers, null, null)
233
+        return sslContext.serverSocketFactory.createServerSocket(port)
234
+    }
235
+
236
+    private fun getTrustingManager() = object : X509TrustManager {
237
+        override fun getAcceptedIssuers(): Array<X509Certificate>  = emptyArray()
238
+
239
+        override fun checkClientTrusted(certs: Array<X509Certificate>, authType: String) {}
240
+
241
+        override fun checkServerTrusted(certs: Array<X509Certificate>, authType: String) {}
242
+    }
243
+
244
+
178 245
 }

Loading…
Cancel
Save