浏览代码

TLS support

tags/v0.2.0
Chris Smith 5 年前
父节点
当前提交
14a192379b

+ 3
- 0
CHANGELOG 查看文件

1
+v0.2.0 [in development]
2
+
3
+  * Added support for connecting over TLS

+ 1
- 0
build.gradle.kts 查看文件

23
     implementation(kotlin("stdlib-jdk8", "1.3.20"))
23
     implementation(kotlin("stdlib-jdk8", "1.3.20"))
24
     implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.1.1")
24
     implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.1.1")
25
     implementation("io.ktor:ktor-network:1.1.2")
25
     implementation("io.ktor:ktor-network:1.1.2")
26
+    implementation("io.ktor:ktor-network-tls:1.1.2")
26
 
27
 
27
     testImplementation("com.nhaarman.mockitokotlin2:mockito-kotlin:2.1.0")
28
     testImplementation("com.nhaarman.mockitokotlin2:mockito-kotlin:2.1.0")
28
     testImplementation("org.junit.jupiter:junit-jupiter-api:5.3.1")
29
     testImplementation("org.junit.jupiter:junit-jupiter-api:5.3.1")

+ 2
- 2
src/main/kotlin/com/dmdirc/ktirc/IrcClient.kt 查看文件

33
 // TODO: Should there be a default profile?
33
 // TODO: Should there be a default profile?
34
 class IrcClientImpl(private val server: Server, private val profile: Profile) : IrcClient {
34
 class IrcClientImpl(private val server: Server, private val profile: Profile) : IrcClient {
35
 
35
 
36
-    var socketFactory: (String, Int) -> LineBufferedSocket = ::KtorLineBufferedSocket
36
+    var socketFactory: (String, Int, Boolean) -> LineBufferedSocket = ::KtorLineBufferedSocket
37
 
37
 
38
     override val serverState = ServerState(profile.initialNick)
38
     override val serverState = ServerState(profile.initialNick)
39
     override val channelState = ChannelStateMap { caseMapping }
39
     override val channelState = ChannelStateMap { caseMapping }
59
         // TODO: Concurrency!
59
         // TODO: Concurrency!
60
         check(socket == null)
60
         check(socket == null)
61
         coroutineScope {
61
         coroutineScope {
62
-            with(socketFactory(server.host, server.port)) {
62
+            with(socketFactory(server.host, server.port, server.tls)) {
63
                 socket = this
63
                 socket = this
64
                 connect()
64
                 connect()
65
                 sendLine("CAP LS 302")
65
                 sendLine("CAP LS 302")

+ 12
- 2
src/main/kotlin/com/dmdirc/ktirc/io/LineBufferedSocket.kt 查看文件

6
 import io.ktor.network.sockets.aSocket
6
 import io.ktor.network.sockets.aSocket
7
 import io.ktor.network.sockets.openReadChannel
7
 import io.ktor.network.sockets.openReadChannel
8
 import io.ktor.network.sockets.openWriteChannel
8
 import io.ktor.network.sockets.openWriteChannel
9
+import io.ktor.network.tls.tls
9
 import kotlinx.coroutines.CoroutineScope
10
 import kotlinx.coroutines.CoroutineScope
10
 import kotlinx.coroutines.Dispatchers
11
 import kotlinx.coroutines.Dispatchers
12
+import kotlinx.coroutines.GlobalScope
11
 import kotlinx.coroutines.channels.ReceiveChannel
13
 import kotlinx.coroutines.channels.ReceiveChannel
12
 import kotlinx.coroutines.channels.produce
14
 import kotlinx.coroutines.channels.produce
13
 import kotlinx.coroutines.io.ByteReadChannel
15
 import kotlinx.coroutines.io.ByteReadChannel
14
 import kotlinx.coroutines.io.ByteWriteChannel
16
 import kotlinx.coroutines.io.ByteWriteChannel
15
 import java.net.InetSocketAddress
17
 import java.net.InetSocketAddress
18
+import java.security.SecureRandom
19
+import javax.net.ssl.X509TrustManager
16
 
20
 
17
 interface LineBufferedSocket {
21
 interface LineBufferedSocket {
18
 
22
 
29
 /**
33
 /**
30
  * Asynchronous socket that buffers incoming data and emits individual lines.
34
  * Asynchronous socket that buffers incoming data and emits individual lines.
31
  */
35
  */
32
-// TODO: TLS options
33
-class KtorLineBufferedSocket(private val host: String, private val port: Int): LineBufferedSocket {
36
+// TODO: Expose advanced TLS options
37
+class KtorLineBufferedSocket(private val host: String, private val port: Int, private val tls: Boolean = false): LineBufferedSocket {
34
 
38
 
35
     companion object {
39
     companion object {
36
         const val CARRIAGE_RETURN = '\r'.toByte()
40
         const val CARRIAGE_RETURN = '\r'.toByte()
37
         const val LINE_FEED = '\n'.toByte()
41
         const val LINE_FEED = '\n'.toByte()
38
     }
42
     }
39
 
43
 
44
+    public var tlsTrustManager: X509TrustManager? = null
45
+
40
     private val log by logger()
46
     private val log by logger()
41
 
47
 
42
     private lateinit var socket: Socket
48
     private lateinit var socket: Socket
47
     override suspend fun connect() {
53
     override suspend fun connect() {
48
         log.info { "Connecting..." }
54
         log.info { "Connecting..." }
49
         socket = aSocket(ActorSelectorManager(Dispatchers.IO)).tcp().connect(InetSocketAddress(host, port))
55
         socket = aSocket(ActorSelectorManager(Dispatchers.IO)).tcp().connect(InetSocketAddress(host, port))
56
+        if (tls) {
57
+            // TODO: Figure out how exactly scopes work...
58
+            socket = socket.tls(GlobalScope.coroutineContext, randomAlgorithm = SecureRandom.getInstanceStrong().algorithm, trustManager = tlsTrustManager)
59
+        }
50
         readChannel = socket.openReadChannel()
60
         readChannel = socket.openReadChannel()
51
         writeChannel = socket.openWriteChannel()
61
         writeChannel = socket.openWriteChannel()
52
     }
62
     }

+ 1
- 1
src/main/kotlin/com/dmdirc/ktirc/model/Server.kt 查看文件

1
 package com.dmdirc.ktirc.model
1
 package com.dmdirc.ktirc.model
2
 
2
 
3
-data class Server(val host: String, val port: Int, val ssl: Boolean = false, val password: String? = null)
3
+data class Server(val host: String, val port: Int, val tls: Boolean = false, val password: String? = null)

+ 16
- 3
src/test/kotlin/com/dmdirc/ktirc/IrcClientTest.kt 查看文件

31
         on { readLines(any()) } doReturn readLineChannel
31
         on { readLines(any()) } doReturn readLineChannel
32
     }
32
     }
33
 
33
 
34
-    private val mockSocketFactory = mock<(String, Int) -> LineBufferedSocket> {
35
-        on { invoke(HOST, PORT) } doReturn mockSocket
34
+    private val mockSocketFactory = mock<(String, Int, Boolean) -> LineBufferedSocket> {
35
+        on { invoke(eq(HOST), eq(PORT), any()) } doReturn mockSocket
36
     }
36
     }
37
 
37
 
38
     private val mockEventHandler = mock<EventHandler>()
38
     private val mockEventHandler = mock<EventHandler>()
51
 
51
 
52
             client.connect()
52
             client.connect()
53
 
53
 
54
-            verify(mockSocketFactory).invoke(HOST, PORT)
54
+            verify(mockSocketFactory).invoke(HOST, PORT, false)
55
+        }
56
+    }
57
+
58
+    @Test
59
+    fun `IrcClientImpl uses socket factory to create a new tls on connect`() {
60
+        runBlocking {
61
+            val client = IrcClientImpl(Server(HOST, PORT, true), Profile(NICK, REAL_NAME, USER_NAME))
62
+            client.socketFactory = mockSocketFactory
63
+            readLineChannel.close()
64
+
65
+            client.connect()
66
+
67
+            verify(mockSocketFactory).invoke(HOST, PORT, true)
55
         }
68
         }
56
     }
69
     }
57
 
70
 

+ 69
- 2
src/test/kotlin/com/dmdirc/ktirc/io/KtorLineBufferedSocketTest.kt 查看文件

1
 package com.dmdirc.ktirc.io
1
 package com.dmdirc.ktirc.io
2
 
2
 
3
+import io.ktor.network.tls.certificates.generateCertificate
3
 import kotlinx.coroutines.GlobalScope
4
 import kotlinx.coroutines.GlobalScope
4
 import kotlinx.coroutines.async
5
 import kotlinx.coroutines.async
5
 import kotlinx.coroutines.launch
6
 import kotlinx.coroutines.launch
6
 import kotlinx.coroutines.runBlocking
7
 import kotlinx.coroutines.runBlocking
7
-import org.junit.jupiter.api.Assertions.assertEquals
8
-import org.junit.jupiter.api.Assertions.assertNotNull
8
+import org.junit.jupiter.api.Assertions.*
9
 import org.junit.jupiter.api.Test
9
 import org.junit.jupiter.api.Test
10
 import org.junit.jupiter.api.parallel.Execution
10
 import org.junit.jupiter.api.parallel.Execution
11
 import org.junit.jupiter.api.parallel.ExecutionMode
11
 import org.junit.jupiter.api.parallel.ExecutionMode
12
+import sun.security.validator.ValidatorException
13
+import java.io.File
12
 import java.net.ServerSocket
14
 import java.net.ServerSocket
15
+import java.security.KeyStore
16
+import java.security.cert.X509Certificate
17
+import javax.net.ssl.KeyManagerFactory
18
+import javax.net.ssl.SSLContext
19
+import javax.net.ssl.X509TrustManager
13
 
20
 
14
 @Execution(ExecutionMode.SAME_THREAD)
21
 @Execution(ExecutionMode.SAME_THREAD)
15
 internal class KtorLineBufferedSocketTest {
22
 internal class KtorLineBufferedSocketTest {
26
         }
33
         }
27
     }
34
     }
28
 
35
 
36
+    @Test
37
+    fun `KtorLineBufferedSocket throws trying to connect to a server with a bad TLS cert`() = runBlocking {
38
+        tlsServerSocket(12321).use { serverSocket ->
39
+            try {
40
+                val socket = KtorLineBufferedSocket("localhost", 12321)
41
+                val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
42
+
43
+                socket.connect()
44
+                assertNotNull(clientSocketAsync.await())
45
+                fail<Unit>()
46
+            } catch (ex : ValidatorException) {
47
+                // Expected
48
+            }
49
+        }
50
+    }
51
+
29
     @Test
52
     @Test
30
     fun `KtorLineBufferedSocket can send a whole byte array to a server`() = runBlocking {
53
     fun `KtorLineBufferedSocket can send a whole byte array to a server`() = runBlocking {
31
         ServerSocket(12321).use { serverSocket ->
54
         ServerSocket(12321).use { serverSocket ->
64
         }
87
         }
65
     }
88
     }
66
 
89
 
90
+    @Test
91
+    fun `KtorLineBufferedSocket can send a string to a server over TLS`() = runBlocking {
92
+        tlsServerSocket(12321).use { serverSocket ->
93
+            val socket = KtorLineBufferedSocket("localhost", 12321, true)
94
+            socket.tlsTrustManager = getTrustingManager()
95
+            val clientBytesAsync = GlobalScope.async {
96
+                ByteArray(13).apply {
97
+                    serverSocket.accept().getInputStream().read(this)
98
+                }
99
+            }
100
+
101
+            socket.connect()
102
+            socket.sendLine("Hello World")
103
+
104
+            val bytes = clientBytesAsync.await()
105
+            assertNotNull(bytes)
106
+            assertEquals("Hello World\r\n", String(bytes))
107
+        }
108
+    }
109
+
67
     @Test
110
     @Test
68
     fun `KtorLineBufferedSocket can send a partial byte array to a server`() = runBlocking {
111
     fun `KtorLineBufferedSocket can send a partial byte array to a server`() = runBlocking {
69
         ServerSocket(12321).use { serverSocket ->
112
         ServerSocket(12321).use { serverSocket ->
175
         }
218
         }
176
     }
219
     }
177
 
220
 
221
+    private fun tlsServerSocket(port: Int): ServerSocket {
222
+        val keyFile = File.createTempFile("selfsigned", "jks")
223
+        generateCertificate(keyFile)
224
+
225
+        val keyStore = KeyStore.getInstance("JKS")
226
+        keyStore.load(keyFile.inputStream(), "changeit".toCharArray())
227
+
228
+        val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
229
+        keyManagerFactory.init(keyStore, "changeit".toCharArray())
230
+
231
+        val sslContext = SSLContext.getInstance("TLSv1.2")
232
+        sslContext.init(keyManagerFactory.keyManagers, null, null)
233
+        return sslContext.serverSocketFactory.createServerSocket(port)
234
+    }
235
+
236
+    private fun getTrustingManager() = object : X509TrustManager {
237
+        override fun getAcceptedIssuers(): Array<X509Certificate>  = emptyArray()
238
+
239
+        override fun checkClientTrusted(certs: Array<X509Certificate>, authType: String) {}
240
+
241
+        override fun checkServerTrusted(certs: Array<X509Certificate>, authType: String) {}
242
+    }
243
+
244
+
178
 }
245
 }

正在加载...
取消
保存