Browse Source

Replace ktor with plain socket code.

Closes #14
tags/v1.0.0
Chris Smith 5 years ago
parent
commit
96449f98a1

+ 3
- 0
CHANGELOG View File

@@ -1,5 +1,8 @@
1 1
 vNEXT (in development)
2 2
 
3
+ * Replaced Ktor dependency with custom socket handling, which fixes
4
+   fatal issue when connecting to servers over TLS that request a
5
+   client certificate.
3 6
  * Added NicknameChangeRequired event for the case when a nickname is
4 7
    not allowed during connection and *MUST* be changed
5 8
 

+ 0
- 9
README.md View File

@@ -77,15 +77,6 @@ client.connect()
77 77
 
78 78
 ## Known issues / FAQ
79 79
 
80
-### `java.lang.IllegalStateException: Check failed` when connecting to some servers
81
-
82
-This happens when the IRC server requests an optional client certificate (for use
83
-in SASL EXTERNAL auth, usually). At present there is no support for client
84
-certificates in the networking library used by KtIrc. This is fixed in the
85
-[upstream library](https://github.com/ktorio/ktor/issues/641) and will be included
86
-as soon as snapshot builds are available. There is no workaround other than using
87
-an insecure connection.
88
-
89 80
 ### KtIrc connects over IPv4 even when host has IPv6
90 81
 
91 82
 This is an issue with the Java standard library. You can change its behaviour by

+ 2
- 2
build.gradle.kts View File

@@ -39,8 +39,8 @@ repositories {
39 39
 dependencies {
40 40
     implementation(kotlin("stdlib-jdk8", "1.3.21"))
41 41
     implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.1.1")
42
-    implementation("io.ktor:ktor-network:1.1.3")
43
-    implementation("io.ktor:ktor-network-tls:1.1.3")
42
+    implementation("org.jetbrains.kotlinx:kotlinx-coroutines-io-jvm:0.1.7")
43
+    compile(kotlin("reflect"))
44 44
 
45 45
     testImplementation("org.junit.jupiter:junit-jupiter-api:5.4.0")
46 46
     testImplementation("org.junit.jupiter:junit-jupiter-params:5.4.0")

+ 1
- 1
docs/index.adoc View File

@@ -1099,7 +1099,7 @@ both events will need to be handled separately.
1099 1099
 
1100 1100
 TODO
1101 1101
 
1102
-==== sendNickChange
1102
+=== sendNickChange
1103 1103
 
1104 1104
 TODO
1105 1105
 

+ 1
- 1
src/itest/kotlin/com/dmdirc/irctest/cases/connection/capabilities/v32/ServerTime.kt View File

@@ -6,7 +6,7 @@ val serverTime = testCase("connection.capabilities.302.server-time") {
6 6
     steps {
7 7
         expect("CAP LS 302")
8 8
         send("CAP * LS :server-time")
9
-        expect("CAP REQ :server-time")
9
+        expect("CAP REQ server-time")
10 10
         send("CAP * ACK :server-time")
11 11
         expect("CAP END")
12 12
     }

+ 3
- 6
src/main/kotlin/com/dmdirc/ktirc/IrcClientImpl.kt View File

@@ -3,8 +3,8 @@ package com.dmdirc.ktirc
3 3
 import com.dmdirc.ktirc.events.*
4 4
 import com.dmdirc.ktirc.events.handlers.eventHandlers
5 5
 import com.dmdirc.ktirc.events.mutators.eventMutators
6
-import com.dmdirc.ktirc.io.KtorLineBufferedSocket
7 6
 import com.dmdirc.ktirc.io.LineBufferedSocket
7
+import com.dmdirc.ktirc.io.LineBufferedSocketImpl
8 8
 import com.dmdirc.ktirc.io.MessageHandler
9 9
 import com.dmdirc.ktirc.io.MessageParser
10 10
 import com.dmdirc.ktirc.messages.*
@@ -13,11 +13,9 @@ import com.dmdirc.ktirc.model.*
13 13
 import com.dmdirc.ktirc.util.currentTimeProvider
14 14
 import com.dmdirc.ktirc.util.generateLabel
15 15
 import com.dmdirc.ktirc.util.logger
16
-import io.ktor.util.KtorExperimentalAPI
17 16
 import kotlinx.coroutines.*
18 17
 import kotlinx.coroutines.channels.Channel
19 18
 import kotlinx.coroutines.channels.map
20
-import kotlinx.coroutines.time.withTimeoutOrNull
21 19
 import java.time.Duration
22 20
 import java.util.concurrent.atomic.AtomicBoolean
23 21
 import java.util.logging.Level
@@ -35,8 +33,7 @@ internal class IrcClientImpl(private val config: IrcClientConfig) : Experimental
35 33
     override val coroutineContext = GlobalScope.newCoroutineContext(Dispatchers.IO)
36 34
 
37 35
     @ExperimentalCoroutinesApi
38
-    @KtorExperimentalAPI
39
-    internal var socketFactory: (CoroutineScope, String, Int, Boolean) -> LineBufferedSocket = ::KtorLineBufferedSocket
36
+    internal var socketFactory: (CoroutineScope, String, Int, Boolean) -> LineBufferedSocket = ::LineBufferedSocketImpl
40 37
 
41 38
     internal var asyncTimeout = Duration.ofSeconds(20)
42 39
 
@@ -76,7 +73,7 @@ internal class IrcClientImpl(private val config: IrcClientConfig) : Experimental
76 73
             send(tags, command, *arguments)
77 74
         }
78 75
 
79
-        withTimeoutOrNull(asyncTimeout) {
76
+        withTimeoutOrNull(asyncTimeout.toMillis()) {
80 77
             channel.receive()
81 78
         }.also { serverState.asyncResponseState.pendingResponses.remove(label) }
82 79
     }

+ 15
- 21
src/main/kotlin/com/dmdirc/ktirc/io/LineBufferedSocket.kt View File

@@ -1,23 +1,17 @@
1 1
 package com.dmdirc.ktirc.io
2 2
 
3 3
 import com.dmdirc.ktirc.util.logger
4
-import io.ktor.network.selector.ActorSelectorManager
5
-import io.ktor.network.sockets.Socket
6
-import io.ktor.network.sockets.aSocket
7
-import io.ktor.network.sockets.openReadChannel
8
-import io.ktor.network.sockets.openWriteChannel
9
-import io.ktor.network.tls.tls
10
-import io.ktor.util.KtorExperimentalAPI
11 4
 import kotlinx.coroutines.*
12 5
 import kotlinx.coroutines.channels.Channel
13 6
 import kotlinx.coroutines.channels.ReceiveChannel
14 7
 import kotlinx.coroutines.channels.SendChannel
15 8
 import kotlinx.coroutines.channels.produce
16
-import kotlinx.coroutines.io.ByteReadChannel
17 9
 import kotlinx.coroutines.io.ByteWriteChannel
18 10
 import java.net.InetSocketAddress
11
+import java.nio.ByteBuffer
19 12
 import java.security.SecureRandom
20 13
 import java.security.cert.CertificateException
14
+import javax.net.ssl.SSLContext
21 15
 import javax.net.ssl.X509TrustManager
22 16
 
23 17
 internal interface LineBufferedSocket {
@@ -36,9 +30,8 @@ internal interface LineBufferedSocket {
36 30
  * Asynchronous socket that buffers incoming data and emits individual lines.
37 31
  */
38 32
 // TODO: Expose advanced TLS options
39
-@KtorExperimentalAPI
40 33
 @ExperimentalCoroutinesApi
41
-internal class KtorLineBufferedSocket(coroutineScope: CoroutineScope, private val host: String, private val port: Int, private val tls: Boolean = false) : CoroutineScope, LineBufferedSocket {
34
+internal class LineBufferedSocketImpl(coroutineScope: CoroutineScope, private val host: String, private val port: Int, private val tls: Boolean = false) : CoroutineScope, LineBufferedSocket {
42 35
 
43 36
     companion object {
44 37
         const val CARRIAGE_RETURN = '\r'.toByte()
@@ -53,21 +46,22 @@ internal class KtorLineBufferedSocket(coroutineScope: CoroutineScope, private va
53 46
     private val log by logger()
54 47
 
55 48
     private lateinit var socket: Socket
56
-    private lateinit var readChannel: ByteReadChannel
57 49
     private lateinit var writeChannel: ByteWriteChannel
58 50
 
59 51
     override fun connect() {
52
+        log.info { "Connecting..." }
53
+        socket = PlainTextSocket(this)
54
+
60 55
         runBlocking {
61
-            log.info { "Connecting..." }
62
-            socket = aSocket(ActorSelectorManager(Dispatchers.IO)).tcp().connect(InetSocketAddress(host, port))
63 56
             if (tls) {
64
-                socket = socket.tls(
65
-                        coroutineContext = this@KtorLineBufferedSocket.coroutineContext,
66
-                        randomAlgorithm = SecureRandom.getInstanceStrong().algorithm,
67
-                        trustManager = tlsTrustManager)
57
+                with (SSLContext.getInstance("TLSv1.2")) {
58
+                    init(null, tlsTrustManager?.let { arrayOf(it) }, SecureRandom.getInstanceStrong())
59
+                    socket = TlsSocket(this@LineBufferedSocketImpl, socket, this, host)
60
+                }
68 61
             }
69
-            readChannel = socket.openReadChannel()
70
-            writeChannel = socket.openWriteChannel()
62
+            socket.connect(InetSocketAddress(host, port))
63
+            println("connected!")
64
+            writeChannel = socket.write
71 65
         }
72 66
         launch { writeLines() }
73 67
     }
@@ -82,9 +76,9 @@ internal class KtorLineBufferedSocket(coroutineScope: CoroutineScope, private va
82 76
         get() = produce {
83 77
             val lineBuffer = ByteArray(16384)
84 78
             var nextByteOffset = 0
85
-            while (!readChannel.isClosedForRead) {
79
+            while (socket.isOpen) {
86 80
                 var lineStart = 0
87
-                val bytesRead = readChannel.readAvailable(lineBuffer, nextByteOffset, lineBuffer.size - nextByteOffset)
81
+                val bytesRead = socket.read(ByteBuffer.wrap(lineBuffer).apply { position(nextByteOffset) })
88 82
                 for (i in nextByteOffset until nextByteOffset + bytesRead) {
89 83
                     if (lineBuffer[i] == CARRIAGE_RETURN || lineBuffer[i] == LINE_FEED) {
90 84
                         if (lineStart < i) {

+ 148
- 0
src/main/kotlin/com/dmdirc/ktirc/io/Sockets.kt View File

@@ -0,0 +1,148 @@
1
+package com.dmdirc.ktirc.io
2
+
3
+import kotlinx.coroutines.CancellableContinuation
4
+import kotlinx.coroutines.CoroutineScope
5
+import kotlinx.coroutines.io.ByteChannel
6
+import kotlinx.coroutines.io.ByteWriteChannel
7
+import kotlinx.coroutines.io.close
8
+import kotlinx.coroutines.launch
9
+import kotlinx.coroutines.suspendCancellableCoroutine
10
+import kotlinx.io.pool.DefaultPool
11
+import java.net.SocketAddress
12
+import java.nio.ByteBuffer
13
+import java.nio.channels.*
14
+import kotlin.coroutines.resume
15
+import kotlin.coroutines.resumeWithException
16
+
17
+internal const val BUFFER_SIZE = 32768
18
+internal const val POOL_SIZE = 16
19
+
20
+internal val defaultPool = ByteBufferPool()
21
+
22
+internal class ByteBufferPool : DefaultPool<ByteBuffer>(POOL_SIZE) {
23
+    override fun produceInstance(): ByteBuffer = ByteBuffer.allocate(BUFFER_SIZE)
24
+    override fun clearInstance(instance: ByteBuffer): ByteBuffer = instance.apply { clear() }
25
+
26
+    inline fun <T> borrow(block: (ByteBuffer) -> T): T {
27
+        val buffer = borrow()
28
+        try {
29
+            return block(buffer)
30
+        } finally {
31
+            recycle(buffer)
32
+        }
33
+    }
34
+}
35
+
36
+internal interface Socket {
37
+    fun bind(socketAddress: SocketAddress)
38
+    suspend fun connect(socketAddress: SocketAddress)
39
+    suspend fun read(buffer: ByteBuffer): Int
40
+    fun close()
41
+    val write: ByteWriteChannel
42
+    val isOpen: Boolean
43
+}
44
+
45
+internal class PlainTextSocket(private val scope: CoroutineScope) : Socket {
46
+
47
+    private val client = AsynchronousSocketChannel.open()
48
+    private var writeChannel = ByteChannel(autoFlush = true)
49
+
50
+    override val write: ByteWriteChannel
51
+        get() = writeChannel
52
+
53
+    override val isOpen: Boolean
54
+        get() = client.isOpen
55
+
56
+    override fun bind(socketAddress: SocketAddress) {
57
+        client.bind(socketAddress)
58
+    }
59
+
60
+    override suspend fun connect(socketAddress: SocketAddress) {
61
+        writeChannel = ByteChannel(autoFlush = true)
62
+
63
+        suspendCancellableCoroutine<Unit> { continuation ->
64
+            client.closeOnCancel(continuation)
65
+            client.connect(socketAddress, continuation, AsyncVoidIOHandler)
66
+        }
67
+
68
+        scope.launch { writeLoop() }
69
+    }
70
+
71
+    override fun close() {
72
+        writeChannel.close()
73
+        client.close()
74
+    }
75
+
76
+    override suspend fun read(buffer: ByteBuffer) = try {
77
+        val bytes = suspendCancellableCoroutine<Int> { continuation ->
78
+
79
+            client.closeOnCancel(continuation)
80
+            client.read(buffer, continuation, asyncIOHandler())
81
+        }
82
+
83
+        if (bytes == -1) {
84
+            close()
85
+        }
86
+        bytes
87
+    } catch (_: ClosedChannelException) {
88
+        // Ignore
89
+        0
90
+    }
91
+
92
+    private suspend fun writeLoop() {
93
+        while (client.isOpen) {
94
+            defaultPool.borrow { buffer ->
95
+                writeChannel.readAvailable(buffer)
96
+                buffer.flip()
97
+                try {
98
+                    suspendCancellableCoroutine<Int> { continuation ->
99
+                        client.closeOnCancel(continuation)
100
+                        client.write(buffer, continuation, asyncIOHandler())
101
+                    }
102
+                } catch (_: ClosedChannelException) {
103
+                    // Ignore
104
+                }
105
+            }
106
+        }
107
+    }
108
+
109
+}
110
+
111
+private fun Channel.closeOnCancel(cont: CancellableContinuation<*>) {
112
+    cont.invokeOnCancellation {
113
+        try {
114
+            close()
115
+        } catch (ex: Throwable) {
116
+            // Specification says that it is Ok to call it any time, but reality is different,
117
+            // so we have just to ignore exception
118
+        }
119
+    }
120
+}
121
+
122
+@Suppress("UNCHECKED_CAST")
123
+private fun <T> asyncIOHandler(): CompletionHandler<T, CancellableContinuation<T>> =
124
+        AsyncIOHandlerAny as CompletionHandler<T, CancellableContinuation<T>>
125
+
126
+private object AsyncIOHandlerAny : CompletionHandler<Any, CancellableContinuation<Any>> {
127
+    override fun completed(result: Any, cont: CancellableContinuation<Any>) {
128
+        cont.resume(result)
129
+    }
130
+
131
+    override fun failed(ex: Throwable, cont: CancellableContinuation<Any>) {
132
+        // just return if already cancelled and got an expected exception for that case
133
+        if (ex is AsynchronousCloseException && cont.isCancelled) return
134
+        cont.resumeWithException(ex)
135
+    }
136
+}
137
+
138
+private object AsyncVoidIOHandler : CompletionHandler<Void?, CancellableContinuation<Unit>> {
139
+    override fun completed(result: Void?, cont: CancellableContinuation<Unit>) {
140
+        cont.resume(Unit)
141
+    }
142
+
143
+    override fun failed(ex: Throwable, cont: CancellableContinuation<Unit>) {
144
+        // just return if already cancelled and got an expected exception for that case
145
+        if (ex is AsynchronousCloseException && cont.isCancelled) return
146
+        cont.resumeWithException(ex)
147
+    }
148
+}

+ 196
- 0
src/main/kotlin/com/dmdirc/ktirc/io/Tls.kt View File

@@ -0,0 +1,196 @@
1
+package com.dmdirc.ktirc.io
2
+
3
+import kotlinx.coroutines.CoroutineScope
4
+import kotlinx.coroutines.io.ByteChannel
5
+import kotlinx.coroutines.io.ByteWriteChannel
6
+import kotlinx.coroutines.launch
7
+import kotlinx.coroutines.sync.Mutex
8
+import kotlinx.coroutines.sync.withLock
9
+import java.net.SocketAddress
10
+import java.nio.ByteBuffer
11
+import java.security.cert.CertificateException
12
+import java.security.cert.X509Certificate
13
+import java.util.regex.Pattern
14
+import javax.naming.ldap.LdapName
15
+import javax.naming.ldap.Rdn
16
+import javax.net.ssl.SSLContext
17
+import javax.net.ssl.SSLEngine
18
+import javax.net.ssl.SSLEngineResult
19
+
20
+
21
+internal class TlsSocket(
22
+        private val scope: CoroutineScope,
23
+        private val socket: Socket,
24
+        private val sslContext: SSLContext,
25
+        private val hostname: String
26
+) : Socket {
27
+
28
+    private var engine: SSLEngine = sslContext.createSSLEngine()
29
+
30
+    private var incomingNetBuffer = ByteBuffer.allocate(0)
31
+    private var outgoingAppBuffer = ByteBuffer.allocate(0)
32
+    private var incomingAppBuffer = ByteBuffer.allocate(0)
33
+    private val outgoingAppBufferMutex = Mutex(false)
34
+
35
+    private var writeChannel = ByteChannel(autoFlush = true)
36
+
37
+    override val write: ByteWriteChannel
38
+        get() = writeChannel
39
+
40
+    override val isOpen: Boolean
41
+        get() = socket.isOpen
42
+
43
+    override fun bind(socketAddress: SocketAddress) {
44
+        socket.bind(socketAddress)
45
+    }
46
+
47
+    override suspend fun connect(socketAddress: SocketAddress) {
48
+        writeChannel = ByteChannel(autoFlush = true)
49
+
50
+        engine = sslContext.createSSLEngine().apply {
51
+            useClientMode = true
52
+        }
53
+
54
+        incomingNetBuffer = ByteBuffer.allocate(engine.session.packetBufferSize)
55
+        outgoingAppBuffer = ByteBuffer.allocate(engine.session.applicationBufferSize)
56
+        incomingAppBuffer = ByteBuffer.allocate(engine.session.applicationBufferSize)
57
+
58
+        socket.connect(socketAddress)
59
+
60
+        engine.beginHandshake()
61
+
62
+        sslLoop()
63
+    }
64
+
65
+    private suspend fun sslLoop(initialResult: SSLEngineResult? = null) {
66
+        var result: SSLEngineResult? = initialResult
67
+        var handshakeStatus = result?.handshakeStatus ?: engine.handshakeStatus
68
+        while (true) {
69
+            when (handshakeStatus) {
70
+                SSLEngineResult.HandshakeStatus.NEED_TASK -> {
71
+                    engine.delegatedTask.run()
72
+                    handshakeStatus = engine.handshakeStatus
73
+                }
74
+                SSLEngineResult.HandshakeStatus.NEED_WRAP -> {
75
+                    result = wrap()
76
+                    handshakeStatus = result?.handshakeStatus
77
+                }
78
+
79
+                SSLEngineResult.HandshakeStatus.NEED_UNWRAP -> {
80
+                    result = unwrap()
81
+                    handshakeStatus = result?.handshakeStatus
82
+                }
83
+
84
+                SSLEngineResult.HandshakeStatus.FINISHED -> {
85
+                    val certs = engine.session.peerCertificates
86
+                    if (certs.isEmpty() || (certs[0] as? X509Certificate)?.validFor(hostname) == false) {
87
+                        throw CertificateException("Certificate is not valid for $hostname")
88
+                    }
89
+                    scope.launch { readLoop() }
90
+                    scope.launch { writeLoop() }
91
+                    return
92
+                }
93
+
94
+                else -> return
95
+            }
96
+        }
97
+    }
98
+
99
+    override suspend fun read(buffer: ByteBuffer) = outgoingAppBufferMutex.withLock<Int> {
100
+        outgoingAppBuffer.flip()
101
+        val bytes = outgoingAppBuffer.limit()
102
+        buffer.put(outgoingAppBuffer)
103
+        outgoingAppBuffer.clear()
104
+        return bytes
105
+    }
106
+
107
+    private suspend fun wrap(): SSLEngineResult? {
108
+        var result: SSLEngineResult? = null
109
+        defaultPool.borrow { netBuffer ->
110
+            if (engine.handshakeStatus <= SSLEngineResult.HandshakeStatus.FINISHED) {
111
+                writeChannel.readAvailable(incomingAppBuffer)
112
+            }
113
+            incomingAppBuffer.flip()
114
+            result = engine.wrap(incomingAppBuffer, netBuffer)
115
+            incomingAppBuffer.compact()
116
+
117
+            netBuffer.flip()
118
+            socket.write.writeFully(netBuffer)
119
+        }
120
+        return result
121
+    }
122
+
123
+    private suspend fun unwrap(): SSLEngineResult? {
124
+        if (incomingNetBuffer.position() == 0) {
125
+            val bytes = socket.read(incomingNetBuffer.slice())
126
+            if (bytes == -1) {
127
+                close()
128
+                return null
129
+            }
130
+            incomingNetBuffer.position(incomingNetBuffer.position() + bytes)
131
+        }
132
+
133
+        incomingNetBuffer.flip()
134
+        outgoingAppBufferMutex.withLock {
135
+            val result = engine.unwrap(incomingNetBuffer, outgoingAppBuffer)
136
+            incomingNetBuffer.compact()
137
+            return result
138
+        }
139
+    }
140
+
141
+    override fun close() {
142
+        socket.close()
143
+    }
144
+
145
+    private suspend fun readLoop() {
146
+        while (socket.isOpen) {
147
+            sslLoop(unwrap())
148
+        }
149
+    }
150
+
151
+    private suspend fun writeLoop() {
152
+        while (socket.isOpen) {
153
+            sslLoop(wrap())
154
+        }
155
+    }
156
+
157
+}
158
+
159
+internal fun X509Certificate.validFor(host: String): Boolean {
160
+    val hostParts = host.split('.')
161
+    return allNames
162
+            .map { it.split('.') }
163
+            .filter { it.size == hostParts.size }
164
+            .filter { it[0].wildCardMatches(hostParts[0]) }
165
+            .any { it.zip(hostParts).slice(1 until hostParts.size).all { (part, host) -> part.equals(host, ignoreCase = true) } }
166
+}
167
+
168
+private fun String.wildCardMatches(host: String) =
169
+        count { it == '*' } <= 1 &&
170
+                host.matches(Regex(split('*').joinToString(".*") { Pattern.quote(it) }, RegexOption.IGNORE_CASE))
171
+
172
+private val X509Certificate.allNames: Sequence<String>
173
+    get() = sequence {
174
+        commonName?.let { yield(it) }
175
+        yieldAll(subjectAlternateNames)
176
+    }
177
+
178
+private val X509Certificate.subjectAlternateNames: Set<String>
179
+    get() = nullOnThrow {
180
+        subjectAlternativeNames
181
+                ?.filter { it[0] == 2 }
182
+                ?.map { it[1].toString() }
183
+                ?.toSet()
184
+    } ?: emptySet()
185
+
186
+private val X509Certificate.commonName: String?
187
+    get() = nullOnThrow { rdns["CN"]?.firstOrNull()?.value?.toString() }
188
+
189
+private val X509Certificate.rdns: Map<String, List<Rdn>>
190
+    get() = LdapName(subjectX500Principal.name).rdns.groupBy { it.type.toUpperCase() }
191
+
192
+private inline fun <S> nullOnThrow(block: () -> S?): S? = try {
193
+    block()
194
+} catch (ex: Throwable) {
195
+    null
196
+}

+ 0
- 2
src/test/kotlin/com/dmdirc/ktirc/IrcClientImplTest.kt View File

@@ -7,7 +7,6 @@ import com.dmdirc.ktirc.messages.tagMap
7 7
 import com.dmdirc.ktirc.model.*
8 8
 import com.dmdirc.ktirc.util.currentTimeProvider
9 9
 import com.dmdirc.ktirc.util.generateLabel
10
-import io.ktor.util.KtorExperimentalAPI
11 10
 import io.mockk.*
12 11
 import kotlinx.coroutines.*
13 12
 import kotlinx.coroutines.channels.Channel
@@ -22,7 +21,6 @@ import java.nio.channels.UnresolvedAddressException
22 21
 import java.security.cert.CertificateException
23 22
 import java.util.concurrent.atomic.AtomicReference
24 23
 
25
-@KtorExperimentalAPI
26 24
 @ExperimentalCoroutinesApi
27 25
 internal class IrcClientImplTest {
28 26
 

src/test/kotlin/com/dmdirc/ktirc/io/KtorLineBufferedSocketTest.kt → src/test/kotlin/com/dmdirc/ktirc/io/LineBufferedSocketImplTest.kt View File

@@ -1,14 +1,11 @@
1 1
 package com.dmdirc.ktirc.io
2 2
 
3
-import io.ktor.network.tls.certificates.generateCertificate
4
-import io.ktor.util.KtorExperimentalAPI
5 3
 import kotlinx.coroutines.*
6 4
 import org.junit.jupiter.api.Assertions.assertEquals
7 5
 import org.junit.jupiter.api.Assertions.assertNotNull
8 6
 import org.junit.jupiter.api.Test
9 7
 import org.junit.jupiter.api.parallel.Execution
10 8
 import org.junit.jupiter.api.parallel.ExecutionMode
11
-import java.io.File
12 9
 import java.net.ServerSocket
13 10
 import java.security.KeyStore
14 11
 import java.security.cert.X509Certificate
@@ -16,15 +13,14 @@ import javax.net.ssl.KeyManagerFactory
16 13
 import javax.net.ssl.SSLContext
17 14
 import javax.net.ssl.X509TrustManager
18 15
 
19
-@KtorExperimentalAPI
20 16
 @ExperimentalCoroutinesApi
21 17
 @Execution(ExecutionMode.SAME_THREAD)
22
-internal class KtorLineBufferedSocketTest {
18
+internal class LineBufferedSocketImplTest {
23 19
 
24 20
     @Test
25 21
     fun `KtorLineBufferedSocket can connect to a server`() = runBlocking {
26 22
         ServerSocket(12321).use { serverSocket ->
27
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
23
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
28 24
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
29 25
 
30 26
             socket.connect()
@@ -36,7 +32,7 @@ internal class KtorLineBufferedSocketTest {
36 32
     @Test
37 33
     fun `KtorLineBufferedSocket can send a byte array to a server`() = runBlocking {
38 34
         ServerSocket(12321).use { serverSocket ->
39
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
35
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
40 36
             val clientBytesAsync = GlobalScope.async {
41 37
                 ByteArray(13).apply {
42 38
                     serverSocket.accept().getInputStream().read(this)
@@ -55,7 +51,7 @@ internal class KtorLineBufferedSocketTest {
55 51
     @Test
56 52
     fun `KtorLineBufferedSocket can send a string to a server over TLS`() = runBlocking {
57 53
         tlsServerSocket(12321).use { serverSocket ->
58
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321, true)
54
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321, true)
59 55
             socket.tlsTrustManager = getTrustingManager()
60 56
             val clientBytesAsync = GlobalScope.async {
61 57
                 ByteArray(13).apply {
@@ -75,7 +71,7 @@ internal class KtorLineBufferedSocketTest {
75 71
     @Test
76 72
     fun `KtorLineBufferedSocket can receive a line of CRLF delimited text`() = runBlocking {
77 73
         ServerSocket(12321).use { serverSocket ->
78
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
74
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
79 75
             GlobalScope.launch {
80 76
                 serverSocket.accept().getOutputStream().write("Hi there\r\n".toByteArray())
81 77
             }
@@ -88,7 +84,7 @@ internal class KtorLineBufferedSocketTest {
88 84
     @Test
89 85
     fun `KtorLineBufferedSocket can receive a line of LF delimited text`() = runBlocking {
90 86
         ServerSocket(12321).use { serverSocket ->
91
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
87
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
92 88
             GlobalScope.launch {
93 89
                 serverSocket.accept().getOutputStream().write("Hi there\n".toByteArray())
94 90
             }
@@ -101,7 +97,7 @@ internal class KtorLineBufferedSocketTest {
101 97
     @Test
102 98
     fun `KtorLineBufferedSocket can receive multiple lines of text in one packet`() = runBlocking {
103 99
         ServerSocket(12321).use { serverSocket ->
104
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
100
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
105 101
             GlobalScope.launch {
106 102
                 serverSocket.accept().getOutputStream().write("Hi there\nThis is a test\r".toByteArray())
107 103
             }
@@ -116,7 +112,7 @@ internal class KtorLineBufferedSocketTest {
116 112
     @Test
117 113
     fun `KtorLineBufferedSocket can receive multiple long lines of text`() = runBlocking {
118 114
         ServerSocket(12321).use { serverSocket ->
119
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
115
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
120 116
             val line1 = "abcdefghijklmnopqrstuvwxyz".repeat(500)
121 117
             val line2 = "1234567890987654321[];'#,.".repeat(500)
122 118
             val line3 = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".repeat(500)
@@ -135,7 +131,7 @@ internal class KtorLineBufferedSocketTest {
135 131
     @Test
136 132
     fun `KtorLineBufferedSocket can receive one line of text over multiple packets`() = runBlocking {
137 133
         ServerSocket(12321).use { serverSocket ->
138
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
134
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
139 135
             GlobalScope.launch {
140 136
                 with(serverSocket.accept().getOutputStream()) {
141 137
                     write("Hi".toByteArray())
@@ -156,7 +152,7 @@ internal class KtorLineBufferedSocketTest {
156 152
     @Test
157 153
     fun `KtorLineBufferedSocket returns from readLines when socket is closed`() = runBlocking {
158 154
         ServerSocket(12321).use { serverSocket ->
159
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
155
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
160 156
             GlobalScope.launch {
161 157
                 with(serverSocket.accept()) {
162 158
                     getOutputStream().write("Hi there\r\n".toByteArray())
@@ -173,7 +169,7 @@ internal class KtorLineBufferedSocketTest {
173 169
     @Test
174 170
     fun `KtorLineBufferedSocket disconnects from server`() = runBlocking {
175 171
         ServerSocket(12321).use { serverSocket ->
176
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
172
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
177 173
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
178 174
 
179 175
             socket.connect()
@@ -184,14 +180,11 @@ internal class KtorLineBufferedSocketTest {
184 180
     }
185 181
 
186 182
     private fun tlsServerSocket(port: Int): ServerSocket {
187
-        val keyFile = File.createTempFile("selfsigned", "jks")
188
-        generateCertificate(keyFile)
189
-
190
-        val keyStore = KeyStore.getInstance("JKS")
191
-        keyStore.load(keyFile.inputStream(), "changeit".toCharArray())
183
+        val keyStore = KeyStore.getInstance("PKCS12")
184
+        keyStore.load(LineBufferedSocketImplTest::class.java.getResourceAsStream("localhost.p12"), CharArray(0))
192 185
 
193 186
         val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
194
-        keyManagerFactory.init(keyStore, "changeit".toCharArray())
187
+        keyManagerFactory.init(keyStore, CharArray(0))
195 188
 
196 189
         val sslContext = SSLContext.getInstance("TLSv1.2")
197 190
         sslContext.init(keyManagerFactory.keyManagers, null, null)

+ 152
- 0
src/test/kotlin/com/dmdirc/ktirc/io/TlsTest.kt View File

@@ -0,0 +1,152 @@
1
+package com.dmdirc.ktirc.io
2
+
3
+import io.mockk.every
4
+import io.mockk.mockk
5
+import org.junit.jupiter.api.Assertions.assertFalse
6
+import org.junit.jupiter.api.Assertions.assertTrue
7
+import org.junit.jupiter.api.Test
8
+import java.security.cert.CertificateException
9
+import java.security.cert.X509Certificate
10
+
11
+internal class CertificateValidationTest {
12
+
13
+    private val cert = mockk<X509Certificate>()
14
+
15
+    @Test
16
+    fun `checks common name`() {
17
+        every { cert.subjectX500Principal } returns mockk {
18
+            every { name } returns "CN=subdomain.test.ktirc,O=testing,L=London,C=GB"
19
+        }
20
+
21
+        assertTrue(cert.validFor("subdomain.test.ktirc"))
22
+        assertFalse(cert.validFor("subdomain2.test.ktirc"))
23
+        assertFalse(cert.validFor("testing"))
24
+    }
25
+
26
+    @Test
27
+    fun `checks common name with suffixed wildcard`() {
28
+        every { cert.subjectX500Principal } returns mockk {
29
+            every { name } returns "CN=subdomain*.test.ktirc,O=testing,L=London,C=GB"
30
+        }
31
+
32
+        assertTrue(cert.validFor("subdomain.test.ktirc"))
33
+        assertTrue(cert.validFor("subdomain2.test.ktirc"))
34
+        assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
35
+        assertFalse(cert.validFor("1subdomain.test.ktirc"))
36
+    }
37
+
38
+    @Test
39
+    fun `checks common name with preixed wildcard`() {
40
+        every { cert.subjectX500Principal } returns mockk {
41
+            every { name } returns "CN=*subdomain.test.ktirc,O=testing,L=London,C=GB"
42
+        }
43
+
44
+        assertTrue(cert.validFor("subdomain.test.ktirc"))
45
+        assertTrue(cert.validFor("1subdomain.test.ktirc"))
46
+        assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
47
+        assertFalse(cert.validFor("subdomain1.test.ktirc"))
48
+    }
49
+
50
+    @Test
51
+    fun `checks common name with infixed wildcard`() {
52
+        every { cert.subjectX500Principal } returns mockk {
53
+            every { name } returns "CN=sub*domain.test.ktirc,O=testing,L=London,C=GB"
54
+        }
55
+
56
+        assertTrue(cert.validFor("subdomain.test.ktirc"))
57
+        assertTrue(cert.validFor("SUB-domain.test.ktirc"))
58
+        assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
59
+        assertFalse(cert.validFor("subdomain1.test.ktirc"))
60
+    }
61
+
62
+    @Test
63
+    fun `ignores wildcards in CN if they're not left-most`() {
64
+        every { cert.subjectX500Principal } returns mockk {
65
+            every { name } returns "CN=foo.*domain.test.ktirc,O=testing,L=London,C=GB"
66
+        }
67
+
68
+        assertFalse(cert.validFor("foo.domain.test.ktirc"))
69
+        assertFalse(cert.validFor("foo-test.domain.test.ktirc"))
70
+        assertFalse(cert.validFor("foo.test-domain.test.ktirc"))
71
+    }
72
+
73
+    @Test
74
+    fun `ignores wildcards in CN if there are too many`() {
75
+        every { cert.subjectX500Principal } returns mockk {
76
+            every { name } returns "CN=*domain*.test.ktirc,O=testing,L=London,C=GB"
77
+        }
78
+
79
+        assertFalse(cert.validFor("domain.test.ktirc"))
80
+        assertFalse(cert.validFor("subdomain.test.ktirc"))
81
+        assertFalse(cert.validFor("domain1.test.ktirc"))
82
+    }
83
+
84
+    @Test
85
+    fun `checks all sans`() {
86
+        every { cert.subjectAlternativeNames } returns listOf(
87
+                listOf(4, "directory.test.ktirc"),
88
+                listOf(2, "subdomain1.test.ktirc"),
89
+                listOf(2, "subdomain2.test.ktirc"),
90
+                listOf(2, "subdomain3.test.ktirc")
91
+        )
92
+
93
+        assertTrue(cert.validFor("subdomain1.test.ktirc"))
94
+        assertTrue(cert.validFor("subdomain2.test.KTIRC"))
95
+        assertTrue(cert.validFor("subdomain3.test.ktirc"))
96
+        assertFalse(cert.validFor("directory.test.ktirc"))
97
+    }
98
+
99
+    @Test
100
+    fun `checks wildcard sans`() {
101
+        every { cert.subjectAlternativeNames } returns listOf(
102
+                listOf(4, "directory.test.ktirc"),
103
+                listOf(2, "*domain1.test.ktirc"),
104
+                listOf(2, "subdomain*.test.ktirc"),
105
+                listOf(2, "*foo*.test.ktirc"),
106
+                listOf(2, "foo.*.ktirc")
107
+        )
108
+
109
+        assertTrue(cert.validFor("subdomain1.test.ktirc"))
110
+        assertTrue(cert.validFor("subdomain2.test.ktirc"))
111
+        assertTrue(cert.validFor("gooddomain1.TEST.ktirc"))
112
+        assertFalse(cert.validFor("foo.test.ktirc"))
113
+    }
114
+
115
+    @Test
116
+    fun `still uses CN if sans throws`() {
117
+        every { cert.subjectX500Principal } returns mockk {
118
+            every { name } returns "CN=subdomain.test.ktirc,O=testing,L=London,C=GB"
119
+        }
120
+        every { cert.subjectAlternativeNames } throws CertificateException("Oops")
121
+
122
+        assertTrue(cert.validFor("subdomain.test.ktirc"))
123
+        assertFalse(cert.validFor("subdomain2.test.ktirc"))
124
+        assertFalse(cert.validFor("testing"))
125
+    }
126
+
127
+    @Test
128
+    fun `still uses sans if CN throws`() {
129
+        every { cert.subjectX500Principal } throws CertificateException("Oops")
130
+        every { cert.subjectAlternativeNames } returns listOf(
131
+                listOf(4, "directory.test.ktirc"),
132
+                listOf(2, "subdomain1.test.ktirc"),
133
+                listOf(2, "subdomain2.test.ktirc"),
134
+                listOf(2, "subdomain3.test.ktirc")
135
+        )
136
+
137
+        assertTrue(cert.validFor("subdomain1.test.ktirc"))
138
+        assertTrue(cert.validFor("subdomain2.test.KTIRC"))
139
+        assertTrue(cert.validFor("subdomain3.test.ktirc"))
140
+        assertFalse(cert.validFor("directory.test.ktirc"))
141
+    }
142
+
143
+
144
+    @Test
145
+    fun `fails if CN and sans missing`() {
146
+        assertFalse(cert.validFor("subdomain1.test.ktirc"))
147
+        assertFalse(cert.validFor("subdomain2.test.KTIRC"))
148
+        assertFalse(cert.validFor("subdomain3.test.ktirc"))
149
+        assertFalse(cert.validFor("directory.test.ktirc"))
150
+    }
151
+
152
+}

Loading…
Cancel
Save