Browse Source

Replace ktor with plain socket code.

Closes #14
tags/v1.0.0
Chris Smith 5 years ago
parent
commit
96449f98a1

+ 3
- 0
CHANGELOG View File

1
 vNEXT (in development)
1
 vNEXT (in development)
2
 
2
 
3
+ * Replaced Ktor dependency with custom socket handling, which fixes
4
+   fatal issue when connecting to servers over TLS that request a
5
+   client certificate.
3
  * Added NicknameChangeRequired event for the case when a nickname is
6
  * Added NicknameChangeRequired event for the case when a nickname is
4
    not allowed during connection and *MUST* be changed
7
    not allowed during connection and *MUST* be changed
5
 
8
 

+ 0
- 9
README.md View File

77
 
77
 
78
 ## Known issues / FAQ
78
 ## Known issues / FAQ
79
 
79
 
80
-### `java.lang.IllegalStateException: Check failed` when connecting to some servers
81
-
82
-This happens when the IRC server requests an optional client certificate (for use
83
-in SASL EXTERNAL auth, usually). At present there is no support for client
84
-certificates in the networking library used by KtIrc. This is fixed in the
85
-[upstream library](https://github.com/ktorio/ktor/issues/641) and will be included
86
-as soon as snapshot builds are available. There is no workaround other than using
87
-an insecure connection.
88
-
89
 ### KtIrc connects over IPv4 even when host has IPv6
80
 ### KtIrc connects over IPv4 even when host has IPv6
90
 
81
 
91
 This is an issue with the Java standard library. You can change its behaviour by
82
 This is an issue with the Java standard library. You can change its behaviour by

+ 2
- 2
build.gradle.kts View File

39
 dependencies {
39
 dependencies {
40
     implementation(kotlin("stdlib-jdk8", "1.3.21"))
40
     implementation(kotlin("stdlib-jdk8", "1.3.21"))
41
     implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.1.1")
41
     implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.1.1")
42
-    implementation("io.ktor:ktor-network:1.1.3")
43
-    implementation("io.ktor:ktor-network-tls:1.1.3")
42
+    implementation("org.jetbrains.kotlinx:kotlinx-coroutines-io-jvm:0.1.7")
43
+    compile(kotlin("reflect"))
44
 
44
 
45
     testImplementation("org.junit.jupiter:junit-jupiter-api:5.4.0")
45
     testImplementation("org.junit.jupiter:junit-jupiter-api:5.4.0")
46
     testImplementation("org.junit.jupiter:junit-jupiter-params:5.4.0")
46
     testImplementation("org.junit.jupiter:junit-jupiter-params:5.4.0")

+ 1
- 1
docs/index.adoc View File

1099
 
1099
 
1100
 TODO
1100
 TODO
1101
 
1101
 
1102
-==== sendNickChange
1102
+=== sendNickChange
1103
 
1103
 
1104
 TODO
1104
 TODO
1105
 
1105
 

+ 1
- 1
src/itest/kotlin/com/dmdirc/irctest/cases/connection/capabilities/v32/ServerTime.kt View File

6
     steps {
6
     steps {
7
         expect("CAP LS 302")
7
         expect("CAP LS 302")
8
         send("CAP * LS :server-time")
8
         send("CAP * LS :server-time")
9
-        expect("CAP REQ :server-time")
9
+        expect("CAP REQ server-time")
10
         send("CAP * ACK :server-time")
10
         send("CAP * ACK :server-time")
11
         expect("CAP END")
11
         expect("CAP END")
12
     }
12
     }

+ 3
- 6
src/main/kotlin/com/dmdirc/ktirc/IrcClientImpl.kt View File

3
 import com.dmdirc.ktirc.events.*
3
 import com.dmdirc.ktirc.events.*
4
 import com.dmdirc.ktirc.events.handlers.eventHandlers
4
 import com.dmdirc.ktirc.events.handlers.eventHandlers
5
 import com.dmdirc.ktirc.events.mutators.eventMutators
5
 import com.dmdirc.ktirc.events.mutators.eventMutators
6
-import com.dmdirc.ktirc.io.KtorLineBufferedSocket
7
 import com.dmdirc.ktirc.io.LineBufferedSocket
6
 import com.dmdirc.ktirc.io.LineBufferedSocket
7
+import com.dmdirc.ktirc.io.LineBufferedSocketImpl
8
 import com.dmdirc.ktirc.io.MessageHandler
8
 import com.dmdirc.ktirc.io.MessageHandler
9
 import com.dmdirc.ktirc.io.MessageParser
9
 import com.dmdirc.ktirc.io.MessageParser
10
 import com.dmdirc.ktirc.messages.*
10
 import com.dmdirc.ktirc.messages.*
13
 import com.dmdirc.ktirc.util.currentTimeProvider
13
 import com.dmdirc.ktirc.util.currentTimeProvider
14
 import com.dmdirc.ktirc.util.generateLabel
14
 import com.dmdirc.ktirc.util.generateLabel
15
 import com.dmdirc.ktirc.util.logger
15
 import com.dmdirc.ktirc.util.logger
16
-import io.ktor.util.KtorExperimentalAPI
17
 import kotlinx.coroutines.*
16
 import kotlinx.coroutines.*
18
 import kotlinx.coroutines.channels.Channel
17
 import kotlinx.coroutines.channels.Channel
19
 import kotlinx.coroutines.channels.map
18
 import kotlinx.coroutines.channels.map
20
-import kotlinx.coroutines.time.withTimeoutOrNull
21
 import java.time.Duration
19
 import java.time.Duration
22
 import java.util.concurrent.atomic.AtomicBoolean
20
 import java.util.concurrent.atomic.AtomicBoolean
23
 import java.util.logging.Level
21
 import java.util.logging.Level
35
     override val coroutineContext = GlobalScope.newCoroutineContext(Dispatchers.IO)
33
     override val coroutineContext = GlobalScope.newCoroutineContext(Dispatchers.IO)
36
 
34
 
37
     @ExperimentalCoroutinesApi
35
     @ExperimentalCoroutinesApi
38
-    @KtorExperimentalAPI
39
-    internal var socketFactory: (CoroutineScope, String, Int, Boolean) -> LineBufferedSocket = ::KtorLineBufferedSocket
36
+    internal var socketFactory: (CoroutineScope, String, Int, Boolean) -> LineBufferedSocket = ::LineBufferedSocketImpl
40
 
37
 
41
     internal var asyncTimeout = Duration.ofSeconds(20)
38
     internal var asyncTimeout = Duration.ofSeconds(20)
42
 
39
 
76
             send(tags, command, *arguments)
73
             send(tags, command, *arguments)
77
         }
74
         }
78
 
75
 
79
-        withTimeoutOrNull(asyncTimeout) {
76
+        withTimeoutOrNull(asyncTimeout.toMillis()) {
80
             channel.receive()
77
             channel.receive()
81
         }.also { serverState.asyncResponseState.pendingResponses.remove(label) }
78
         }.also { serverState.asyncResponseState.pendingResponses.remove(label) }
82
     }
79
     }

+ 15
- 21
src/main/kotlin/com/dmdirc/ktirc/io/LineBufferedSocket.kt View File

1
 package com.dmdirc.ktirc.io
1
 package com.dmdirc.ktirc.io
2
 
2
 
3
 import com.dmdirc.ktirc.util.logger
3
 import com.dmdirc.ktirc.util.logger
4
-import io.ktor.network.selector.ActorSelectorManager
5
-import io.ktor.network.sockets.Socket
6
-import io.ktor.network.sockets.aSocket
7
-import io.ktor.network.sockets.openReadChannel
8
-import io.ktor.network.sockets.openWriteChannel
9
-import io.ktor.network.tls.tls
10
-import io.ktor.util.KtorExperimentalAPI
11
 import kotlinx.coroutines.*
4
 import kotlinx.coroutines.*
12
 import kotlinx.coroutines.channels.Channel
5
 import kotlinx.coroutines.channels.Channel
13
 import kotlinx.coroutines.channels.ReceiveChannel
6
 import kotlinx.coroutines.channels.ReceiveChannel
14
 import kotlinx.coroutines.channels.SendChannel
7
 import kotlinx.coroutines.channels.SendChannel
15
 import kotlinx.coroutines.channels.produce
8
 import kotlinx.coroutines.channels.produce
16
-import kotlinx.coroutines.io.ByteReadChannel
17
 import kotlinx.coroutines.io.ByteWriteChannel
9
 import kotlinx.coroutines.io.ByteWriteChannel
18
 import java.net.InetSocketAddress
10
 import java.net.InetSocketAddress
11
+import java.nio.ByteBuffer
19
 import java.security.SecureRandom
12
 import java.security.SecureRandom
20
 import java.security.cert.CertificateException
13
 import java.security.cert.CertificateException
14
+import javax.net.ssl.SSLContext
21
 import javax.net.ssl.X509TrustManager
15
 import javax.net.ssl.X509TrustManager
22
 
16
 
23
 internal interface LineBufferedSocket {
17
 internal interface LineBufferedSocket {
36
  * Asynchronous socket that buffers incoming data and emits individual lines.
30
  * Asynchronous socket that buffers incoming data and emits individual lines.
37
  */
31
  */
38
 // TODO: Expose advanced TLS options
32
 // TODO: Expose advanced TLS options
39
-@KtorExperimentalAPI
40
 @ExperimentalCoroutinesApi
33
 @ExperimentalCoroutinesApi
41
-internal class KtorLineBufferedSocket(coroutineScope: CoroutineScope, private val host: String, private val port: Int, private val tls: Boolean = false) : CoroutineScope, LineBufferedSocket {
34
+internal class LineBufferedSocketImpl(coroutineScope: CoroutineScope, private val host: String, private val port: Int, private val tls: Boolean = false) : CoroutineScope, LineBufferedSocket {
42
 
35
 
43
     companion object {
36
     companion object {
44
         const val CARRIAGE_RETURN = '\r'.toByte()
37
         const val CARRIAGE_RETURN = '\r'.toByte()
53
     private val log by logger()
46
     private val log by logger()
54
 
47
 
55
     private lateinit var socket: Socket
48
     private lateinit var socket: Socket
56
-    private lateinit var readChannel: ByteReadChannel
57
     private lateinit var writeChannel: ByteWriteChannel
49
     private lateinit var writeChannel: ByteWriteChannel
58
 
50
 
59
     override fun connect() {
51
     override fun connect() {
52
+        log.info { "Connecting..." }
53
+        socket = PlainTextSocket(this)
54
+
60
         runBlocking {
55
         runBlocking {
61
-            log.info { "Connecting..." }
62
-            socket = aSocket(ActorSelectorManager(Dispatchers.IO)).tcp().connect(InetSocketAddress(host, port))
63
             if (tls) {
56
             if (tls) {
64
-                socket = socket.tls(
65
-                        coroutineContext = this@KtorLineBufferedSocket.coroutineContext,
66
-                        randomAlgorithm = SecureRandom.getInstanceStrong().algorithm,
67
-                        trustManager = tlsTrustManager)
57
+                with (SSLContext.getInstance("TLSv1.2")) {
58
+                    init(null, tlsTrustManager?.let { arrayOf(it) }, SecureRandom.getInstanceStrong())
59
+                    socket = TlsSocket(this@LineBufferedSocketImpl, socket, this, host)
60
+                }
68
             }
61
             }
69
-            readChannel = socket.openReadChannel()
70
-            writeChannel = socket.openWriteChannel()
62
+            socket.connect(InetSocketAddress(host, port))
63
+            println("connected!")
64
+            writeChannel = socket.write
71
         }
65
         }
72
         launch { writeLines() }
66
         launch { writeLines() }
73
     }
67
     }
82
         get() = produce {
76
         get() = produce {
83
             val lineBuffer = ByteArray(16384)
77
             val lineBuffer = ByteArray(16384)
84
             var nextByteOffset = 0
78
             var nextByteOffset = 0
85
-            while (!readChannel.isClosedForRead) {
79
+            while (socket.isOpen) {
86
                 var lineStart = 0
80
                 var lineStart = 0
87
-                val bytesRead = readChannel.readAvailable(lineBuffer, nextByteOffset, lineBuffer.size - nextByteOffset)
81
+                val bytesRead = socket.read(ByteBuffer.wrap(lineBuffer).apply { position(nextByteOffset) })
88
                 for (i in nextByteOffset until nextByteOffset + bytesRead) {
82
                 for (i in nextByteOffset until nextByteOffset + bytesRead) {
89
                     if (lineBuffer[i] == CARRIAGE_RETURN || lineBuffer[i] == LINE_FEED) {
83
                     if (lineBuffer[i] == CARRIAGE_RETURN || lineBuffer[i] == LINE_FEED) {
90
                         if (lineStart < i) {
84
                         if (lineStart < i) {

+ 148
- 0
src/main/kotlin/com/dmdirc/ktirc/io/Sockets.kt View File

1
+package com.dmdirc.ktirc.io
2
+
3
+import kotlinx.coroutines.CancellableContinuation
4
+import kotlinx.coroutines.CoroutineScope
5
+import kotlinx.coroutines.io.ByteChannel
6
+import kotlinx.coroutines.io.ByteWriteChannel
7
+import kotlinx.coroutines.io.close
8
+import kotlinx.coroutines.launch
9
+import kotlinx.coroutines.suspendCancellableCoroutine
10
+import kotlinx.io.pool.DefaultPool
11
+import java.net.SocketAddress
12
+import java.nio.ByteBuffer
13
+import java.nio.channels.*
14
+import kotlin.coroutines.resume
15
+import kotlin.coroutines.resumeWithException
16
+
17
+internal const val BUFFER_SIZE = 32768
18
+internal const val POOL_SIZE = 16
19
+
20
+internal val defaultPool = ByteBufferPool()
21
+
22
+internal class ByteBufferPool : DefaultPool<ByteBuffer>(POOL_SIZE) {
23
+    override fun produceInstance(): ByteBuffer = ByteBuffer.allocate(BUFFER_SIZE)
24
+    override fun clearInstance(instance: ByteBuffer): ByteBuffer = instance.apply { clear() }
25
+
26
+    inline fun <T> borrow(block: (ByteBuffer) -> T): T {
27
+        val buffer = borrow()
28
+        try {
29
+            return block(buffer)
30
+        } finally {
31
+            recycle(buffer)
32
+        }
33
+    }
34
+}
35
+
36
+internal interface Socket {
37
+    fun bind(socketAddress: SocketAddress)
38
+    suspend fun connect(socketAddress: SocketAddress)
39
+    suspend fun read(buffer: ByteBuffer): Int
40
+    fun close()
41
+    val write: ByteWriteChannel
42
+    val isOpen: Boolean
43
+}
44
+
45
+internal class PlainTextSocket(private val scope: CoroutineScope) : Socket {
46
+
47
+    private val client = AsynchronousSocketChannel.open()
48
+    private var writeChannel = ByteChannel(autoFlush = true)
49
+
50
+    override val write: ByteWriteChannel
51
+        get() = writeChannel
52
+
53
+    override val isOpen: Boolean
54
+        get() = client.isOpen
55
+
56
+    override fun bind(socketAddress: SocketAddress) {
57
+        client.bind(socketAddress)
58
+    }
59
+
60
+    override suspend fun connect(socketAddress: SocketAddress) {
61
+        writeChannel = ByteChannel(autoFlush = true)
62
+
63
+        suspendCancellableCoroutine<Unit> { continuation ->
64
+            client.closeOnCancel(continuation)
65
+            client.connect(socketAddress, continuation, AsyncVoidIOHandler)
66
+        }
67
+
68
+        scope.launch { writeLoop() }
69
+    }
70
+
71
+    override fun close() {
72
+        writeChannel.close()
73
+        client.close()
74
+    }
75
+
76
+    override suspend fun read(buffer: ByteBuffer) = try {
77
+        val bytes = suspendCancellableCoroutine<Int> { continuation ->
78
+
79
+            client.closeOnCancel(continuation)
80
+            client.read(buffer, continuation, asyncIOHandler())
81
+        }
82
+
83
+        if (bytes == -1) {
84
+            close()
85
+        }
86
+        bytes
87
+    } catch (_: ClosedChannelException) {
88
+        // Ignore
89
+        0
90
+    }
91
+
92
+    private suspend fun writeLoop() {
93
+        while (client.isOpen) {
94
+            defaultPool.borrow { buffer ->
95
+                writeChannel.readAvailable(buffer)
96
+                buffer.flip()
97
+                try {
98
+                    suspendCancellableCoroutine<Int> { continuation ->
99
+                        client.closeOnCancel(continuation)
100
+                        client.write(buffer, continuation, asyncIOHandler())
101
+                    }
102
+                } catch (_: ClosedChannelException) {
103
+                    // Ignore
104
+                }
105
+            }
106
+        }
107
+    }
108
+
109
+}
110
+
111
+private fun Channel.closeOnCancel(cont: CancellableContinuation<*>) {
112
+    cont.invokeOnCancellation {
113
+        try {
114
+            close()
115
+        } catch (ex: Throwable) {
116
+            // Specification says that it is Ok to call it any time, but reality is different,
117
+            // so we have just to ignore exception
118
+        }
119
+    }
120
+}
121
+
122
+@Suppress("UNCHECKED_CAST")
123
+private fun <T> asyncIOHandler(): CompletionHandler<T, CancellableContinuation<T>> =
124
+        AsyncIOHandlerAny as CompletionHandler<T, CancellableContinuation<T>>
125
+
126
+private object AsyncIOHandlerAny : CompletionHandler<Any, CancellableContinuation<Any>> {
127
+    override fun completed(result: Any, cont: CancellableContinuation<Any>) {
128
+        cont.resume(result)
129
+    }
130
+
131
+    override fun failed(ex: Throwable, cont: CancellableContinuation<Any>) {
132
+        // just return if already cancelled and got an expected exception for that case
133
+        if (ex is AsynchronousCloseException && cont.isCancelled) return
134
+        cont.resumeWithException(ex)
135
+    }
136
+}
137
+
138
+private object AsyncVoidIOHandler : CompletionHandler<Void?, CancellableContinuation<Unit>> {
139
+    override fun completed(result: Void?, cont: CancellableContinuation<Unit>) {
140
+        cont.resume(Unit)
141
+    }
142
+
143
+    override fun failed(ex: Throwable, cont: CancellableContinuation<Unit>) {
144
+        // just return if already cancelled and got an expected exception for that case
145
+        if (ex is AsynchronousCloseException && cont.isCancelled) return
146
+        cont.resumeWithException(ex)
147
+    }
148
+}

+ 196
- 0
src/main/kotlin/com/dmdirc/ktirc/io/Tls.kt View File

1
+package com.dmdirc.ktirc.io
2
+
3
+import kotlinx.coroutines.CoroutineScope
4
+import kotlinx.coroutines.io.ByteChannel
5
+import kotlinx.coroutines.io.ByteWriteChannel
6
+import kotlinx.coroutines.launch
7
+import kotlinx.coroutines.sync.Mutex
8
+import kotlinx.coroutines.sync.withLock
9
+import java.net.SocketAddress
10
+import java.nio.ByteBuffer
11
+import java.security.cert.CertificateException
12
+import java.security.cert.X509Certificate
13
+import java.util.regex.Pattern
14
+import javax.naming.ldap.LdapName
15
+import javax.naming.ldap.Rdn
16
+import javax.net.ssl.SSLContext
17
+import javax.net.ssl.SSLEngine
18
+import javax.net.ssl.SSLEngineResult
19
+
20
+
21
+internal class TlsSocket(
22
+        private val scope: CoroutineScope,
23
+        private val socket: Socket,
24
+        private val sslContext: SSLContext,
25
+        private val hostname: String
26
+) : Socket {
27
+
28
+    private var engine: SSLEngine = sslContext.createSSLEngine()
29
+
30
+    private var incomingNetBuffer = ByteBuffer.allocate(0)
31
+    private var outgoingAppBuffer = ByteBuffer.allocate(0)
32
+    private var incomingAppBuffer = ByteBuffer.allocate(0)
33
+    private val outgoingAppBufferMutex = Mutex(false)
34
+
35
+    private var writeChannel = ByteChannel(autoFlush = true)
36
+
37
+    override val write: ByteWriteChannel
38
+        get() = writeChannel
39
+
40
+    override val isOpen: Boolean
41
+        get() = socket.isOpen
42
+
43
+    override fun bind(socketAddress: SocketAddress) {
44
+        socket.bind(socketAddress)
45
+    }
46
+
47
+    override suspend fun connect(socketAddress: SocketAddress) {
48
+        writeChannel = ByteChannel(autoFlush = true)
49
+
50
+        engine = sslContext.createSSLEngine().apply {
51
+            useClientMode = true
52
+        }
53
+
54
+        incomingNetBuffer = ByteBuffer.allocate(engine.session.packetBufferSize)
55
+        outgoingAppBuffer = ByteBuffer.allocate(engine.session.applicationBufferSize)
56
+        incomingAppBuffer = ByteBuffer.allocate(engine.session.applicationBufferSize)
57
+
58
+        socket.connect(socketAddress)
59
+
60
+        engine.beginHandshake()
61
+
62
+        sslLoop()
63
+    }
64
+
65
+    private suspend fun sslLoop(initialResult: SSLEngineResult? = null) {
66
+        var result: SSLEngineResult? = initialResult
67
+        var handshakeStatus = result?.handshakeStatus ?: engine.handshakeStatus
68
+        while (true) {
69
+            when (handshakeStatus) {
70
+                SSLEngineResult.HandshakeStatus.NEED_TASK -> {
71
+                    engine.delegatedTask.run()
72
+                    handshakeStatus = engine.handshakeStatus
73
+                }
74
+                SSLEngineResult.HandshakeStatus.NEED_WRAP -> {
75
+                    result = wrap()
76
+                    handshakeStatus = result?.handshakeStatus
77
+                }
78
+
79
+                SSLEngineResult.HandshakeStatus.NEED_UNWRAP -> {
80
+                    result = unwrap()
81
+                    handshakeStatus = result?.handshakeStatus
82
+                }
83
+
84
+                SSLEngineResult.HandshakeStatus.FINISHED -> {
85
+                    val certs = engine.session.peerCertificates
86
+                    if (certs.isEmpty() || (certs[0] as? X509Certificate)?.validFor(hostname) == false) {
87
+                        throw CertificateException("Certificate is not valid for $hostname")
88
+                    }
89
+                    scope.launch { readLoop() }
90
+                    scope.launch { writeLoop() }
91
+                    return
92
+                }
93
+
94
+                else -> return
95
+            }
96
+        }
97
+    }
98
+
99
+    override suspend fun read(buffer: ByteBuffer) = outgoingAppBufferMutex.withLock<Int> {
100
+        outgoingAppBuffer.flip()
101
+        val bytes = outgoingAppBuffer.limit()
102
+        buffer.put(outgoingAppBuffer)
103
+        outgoingAppBuffer.clear()
104
+        return bytes
105
+    }
106
+
107
+    private suspend fun wrap(): SSLEngineResult? {
108
+        var result: SSLEngineResult? = null
109
+        defaultPool.borrow { netBuffer ->
110
+            if (engine.handshakeStatus <= SSLEngineResult.HandshakeStatus.FINISHED) {
111
+                writeChannel.readAvailable(incomingAppBuffer)
112
+            }
113
+            incomingAppBuffer.flip()
114
+            result = engine.wrap(incomingAppBuffer, netBuffer)
115
+            incomingAppBuffer.compact()
116
+
117
+            netBuffer.flip()
118
+            socket.write.writeFully(netBuffer)
119
+        }
120
+        return result
121
+    }
122
+
123
+    private suspend fun unwrap(): SSLEngineResult? {
124
+        if (incomingNetBuffer.position() == 0) {
125
+            val bytes = socket.read(incomingNetBuffer.slice())
126
+            if (bytes == -1) {
127
+                close()
128
+                return null
129
+            }
130
+            incomingNetBuffer.position(incomingNetBuffer.position() + bytes)
131
+        }
132
+
133
+        incomingNetBuffer.flip()
134
+        outgoingAppBufferMutex.withLock {
135
+            val result = engine.unwrap(incomingNetBuffer, outgoingAppBuffer)
136
+            incomingNetBuffer.compact()
137
+            return result
138
+        }
139
+    }
140
+
141
+    override fun close() {
142
+        socket.close()
143
+    }
144
+
145
+    private suspend fun readLoop() {
146
+        while (socket.isOpen) {
147
+            sslLoop(unwrap())
148
+        }
149
+    }
150
+
151
+    private suspend fun writeLoop() {
152
+        while (socket.isOpen) {
153
+            sslLoop(wrap())
154
+        }
155
+    }
156
+
157
+}
158
+
159
+internal fun X509Certificate.validFor(host: String): Boolean {
160
+    val hostParts = host.split('.')
161
+    return allNames
162
+            .map { it.split('.') }
163
+            .filter { it.size == hostParts.size }
164
+            .filter { it[0].wildCardMatches(hostParts[0]) }
165
+            .any { it.zip(hostParts).slice(1 until hostParts.size).all { (part, host) -> part.equals(host, ignoreCase = true) } }
166
+}
167
+
168
+private fun String.wildCardMatches(host: String) =
169
+        count { it == '*' } <= 1 &&
170
+                host.matches(Regex(split('*').joinToString(".*") { Pattern.quote(it) }, RegexOption.IGNORE_CASE))
171
+
172
+private val X509Certificate.allNames: Sequence<String>
173
+    get() = sequence {
174
+        commonName?.let { yield(it) }
175
+        yieldAll(subjectAlternateNames)
176
+    }
177
+
178
+private val X509Certificate.subjectAlternateNames: Set<String>
179
+    get() = nullOnThrow {
180
+        subjectAlternativeNames
181
+                ?.filter { it[0] == 2 }
182
+                ?.map { it[1].toString() }
183
+                ?.toSet()
184
+    } ?: emptySet()
185
+
186
+private val X509Certificate.commonName: String?
187
+    get() = nullOnThrow { rdns["CN"]?.firstOrNull()?.value?.toString() }
188
+
189
+private val X509Certificate.rdns: Map<String, List<Rdn>>
190
+    get() = LdapName(subjectX500Principal.name).rdns.groupBy { it.type.toUpperCase() }
191
+
192
+private inline fun <S> nullOnThrow(block: () -> S?): S? = try {
193
+    block()
194
+} catch (ex: Throwable) {
195
+    null
196
+}

+ 0
- 2
src/test/kotlin/com/dmdirc/ktirc/IrcClientImplTest.kt View File

7
 import com.dmdirc.ktirc.model.*
7
 import com.dmdirc.ktirc.model.*
8
 import com.dmdirc.ktirc.util.currentTimeProvider
8
 import com.dmdirc.ktirc.util.currentTimeProvider
9
 import com.dmdirc.ktirc.util.generateLabel
9
 import com.dmdirc.ktirc.util.generateLabel
10
-import io.ktor.util.KtorExperimentalAPI
11
 import io.mockk.*
10
 import io.mockk.*
12
 import kotlinx.coroutines.*
11
 import kotlinx.coroutines.*
13
 import kotlinx.coroutines.channels.Channel
12
 import kotlinx.coroutines.channels.Channel
22
 import java.security.cert.CertificateException
21
 import java.security.cert.CertificateException
23
 import java.util.concurrent.atomic.AtomicReference
22
 import java.util.concurrent.atomic.AtomicReference
24
 
23
 
25
-@KtorExperimentalAPI
26
 @ExperimentalCoroutinesApi
24
 @ExperimentalCoroutinesApi
27
 internal class IrcClientImplTest {
25
 internal class IrcClientImplTest {
28
 
26
 

src/test/kotlin/com/dmdirc/ktirc/io/KtorLineBufferedSocketTest.kt → src/test/kotlin/com/dmdirc/ktirc/io/LineBufferedSocketImplTest.kt View File

1
 package com.dmdirc.ktirc.io
1
 package com.dmdirc.ktirc.io
2
 
2
 
3
-import io.ktor.network.tls.certificates.generateCertificate
4
-import io.ktor.util.KtorExperimentalAPI
5
 import kotlinx.coroutines.*
3
 import kotlinx.coroutines.*
6
 import org.junit.jupiter.api.Assertions.assertEquals
4
 import org.junit.jupiter.api.Assertions.assertEquals
7
 import org.junit.jupiter.api.Assertions.assertNotNull
5
 import org.junit.jupiter.api.Assertions.assertNotNull
8
 import org.junit.jupiter.api.Test
6
 import org.junit.jupiter.api.Test
9
 import org.junit.jupiter.api.parallel.Execution
7
 import org.junit.jupiter.api.parallel.Execution
10
 import org.junit.jupiter.api.parallel.ExecutionMode
8
 import org.junit.jupiter.api.parallel.ExecutionMode
11
-import java.io.File
12
 import java.net.ServerSocket
9
 import java.net.ServerSocket
13
 import java.security.KeyStore
10
 import java.security.KeyStore
14
 import java.security.cert.X509Certificate
11
 import java.security.cert.X509Certificate
16
 import javax.net.ssl.SSLContext
13
 import javax.net.ssl.SSLContext
17
 import javax.net.ssl.X509TrustManager
14
 import javax.net.ssl.X509TrustManager
18
 
15
 
19
-@KtorExperimentalAPI
20
 @ExperimentalCoroutinesApi
16
 @ExperimentalCoroutinesApi
21
 @Execution(ExecutionMode.SAME_THREAD)
17
 @Execution(ExecutionMode.SAME_THREAD)
22
-internal class KtorLineBufferedSocketTest {
18
+internal class LineBufferedSocketImplTest {
23
 
19
 
24
     @Test
20
     @Test
25
     fun `KtorLineBufferedSocket can connect to a server`() = runBlocking {
21
     fun `KtorLineBufferedSocket can connect to a server`() = runBlocking {
26
         ServerSocket(12321).use { serverSocket ->
22
         ServerSocket(12321).use { serverSocket ->
27
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
23
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
28
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
24
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
29
 
25
 
30
             socket.connect()
26
             socket.connect()
36
     @Test
32
     @Test
37
     fun `KtorLineBufferedSocket can send a byte array to a server`() = runBlocking {
33
     fun `KtorLineBufferedSocket can send a byte array to a server`() = runBlocking {
38
         ServerSocket(12321).use { serverSocket ->
34
         ServerSocket(12321).use { serverSocket ->
39
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
35
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
40
             val clientBytesAsync = GlobalScope.async {
36
             val clientBytesAsync = GlobalScope.async {
41
                 ByteArray(13).apply {
37
                 ByteArray(13).apply {
42
                     serverSocket.accept().getInputStream().read(this)
38
                     serverSocket.accept().getInputStream().read(this)
55
     @Test
51
     @Test
56
     fun `KtorLineBufferedSocket can send a string to a server over TLS`() = runBlocking {
52
     fun `KtorLineBufferedSocket can send a string to a server over TLS`() = runBlocking {
57
         tlsServerSocket(12321).use { serverSocket ->
53
         tlsServerSocket(12321).use { serverSocket ->
58
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321, true)
54
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321, true)
59
             socket.tlsTrustManager = getTrustingManager()
55
             socket.tlsTrustManager = getTrustingManager()
60
             val clientBytesAsync = GlobalScope.async {
56
             val clientBytesAsync = GlobalScope.async {
61
                 ByteArray(13).apply {
57
                 ByteArray(13).apply {
75
     @Test
71
     @Test
76
     fun `KtorLineBufferedSocket can receive a line of CRLF delimited text`() = runBlocking {
72
     fun `KtorLineBufferedSocket can receive a line of CRLF delimited text`() = runBlocking {
77
         ServerSocket(12321).use { serverSocket ->
73
         ServerSocket(12321).use { serverSocket ->
78
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
74
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
79
             GlobalScope.launch {
75
             GlobalScope.launch {
80
                 serverSocket.accept().getOutputStream().write("Hi there\r\n".toByteArray())
76
                 serverSocket.accept().getOutputStream().write("Hi there\r\n".toByteArray())
81
             }
77
             }
88
     @Test
84
     @Test
89
     fun `KtorLineBufferedSocket can receive a line of LF delimited text`() = runBlocking {
85
     fun `KtorLineBufferedSocket can receive a line of LF delimited text`() = runBlocking {
90
         ServerSocket(12321).use { serverSocket ->
86
         ServerSocket(12321).use { serverSocket ->
91
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
87
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
92
             GlobalScope.launch {
88
             GlobalScope.launch {
93
                 serverSocket.accept().getOutputStream().write("Hi there\n".toByteArray())
89
                 serverSocket.accept().getOutputStream().write("Hi there\n".toByteArray())
94
             }
90
             }
101
     @Test
97
     @Test
102
     fun `KtorLineBufferedSocket can receive multiple lines of text in one packet`() = runBlocking {
98
     fun `KtorLineBufferedSocket can receive multiple lines of text in one packet`() = runBlocking {
103
         ServerSocket(12321).use { serverSocket ->
99
         ServerSocket(12321).use { serverSocket ->
104
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
100
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
105
             GlobalScope.launch {
101
             GlobalScope.launch {
106
                 serverSocket.accept().getOutputStream().write("Hi there\nThis is a test\r".toByteArray())
102
                 serverSocket.accept().getOutputStream().write("Hi there\nThis is a test\r".toByteArray())
107
             }
103
             }
116
     @Test
112
     @Test
117
     fun `KtorLineBufferedSocket can receive multiple long lines of text`() = runBlocking {
113
     fun `KtorLineBufferedSocket can receive multiple long lines of text`() = runBlocking {
118
         ServerSocket(12321).use { serverSocket ->
114
         ServerSocket(12321).use { serverSocket ->
119
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
115
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
120
             val line1 = "abcdefghijklmnopqrstuvwxyz".repeat(500)
116
             val line1 = "abcdefghijklmnopqrstuvwxyz".repeat(500)
121
             val line2 = "1234567890987654321[];'#,.".repeat(500)
117
             val line2 = "1234567890987654321[];'#,.".repeat(500)
122
             val line3 = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".repeat(500)
118
             val line3 = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".repeat(500)
135
     @Test
131
     @Test
136
     fun `KtorLineBufferedSocket can receive one line of text over multiple packets`() = runBlocking {
132
     fun `KtorLineBufferedSocket can receive one line of text over multiple packets`() = runBlocking {
137
         ServerSocket(12321).use { serverSocket ->
133
         ServerSocket(12321).use { serverSocket ->
138
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
134
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
139
             GlobalScope.launch {
135
             GlobalScope.launch {
140
                 with(serverSocket.accept().getOutputStream()) {
136
                 with(serverSocket.accept().getOutputStream()) {
141
                     write("Hi".toByteArray())
137
                     write("Hi".toByteArray())
156
     @Test
152
     @Test
157
     fun `KtorLineBufferedSocket returns from readLines when socket is closed`() = runBlocking {
153
     fun `KtorLineBufferedSocket returns from readLines when socket is closed`() = runBlocking {
158
         ServerSocket(12321).use { serverSocket ->
154
         ServerSocket(12321).use { serverSocket ->
159
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
155
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
160
             GlobalScope.launch {
156
             GlobalScope.launch {
161
                 with(serverSocket.accept()) {
157
                 with(serverSocket.accept()) {
162
                     getOutputStream().write("Hi there\r\n".toByteArray())
158
                     getOutputStream().write("Hi there\r\n".toByteArray())
173
     @Test
169
     @Test
174
     fun `KtorLineBufferedSocket disconnects from server`() = runBlocking {
170
     fun `KtorLineBufferedSocket disconnects from server`() = runBlocking {
175
         ServerSocket(12321).use { serverSocket ->
171
         ServerSocket(12321).use { serverSocket ->
176
-            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
172
+            val socket = LineBufferedSocketImpl(GlobalScope, "localhost", 12321)
177
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
173
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
178
 
174
 
179
             socket.connect()
175
             socket.connect()
184
     }
180
     }
185
 
181
 
186
     private fun tlsServerSocket(port: Int): ServerSocket {
182
     private fun tlsServerSocket(port: Int): ServerSocket {
187
-        val keyFile = File.createTempFile("selfsigned", "jks")
188
-        generateCertificate(keyFile)
189
-
190
-        val keyStore = KeyStore.getInstance("JKS")
191
-        keyStore.load(keyFile.inputStream(), "changeit".toCharArray())
183
+        val keyStore = KeyStore.getInstance("PKCS12")
184
+        keyStore.load(LineBufferedSocketImplTest::class.java.getResourceAsStream("localhost.p12"), CharArray(0))
192
 
185
 
193
         val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
186
         val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
194
-        keyManagerFactory.init(keyStore, "changeit".toCharArray())
187
+        keyManagerFactory.init(keyStore, CharArray(0))
195
 
188
 
196
         val sslContext = SSLContext.getInstance("TLSv1.2")
189
         val sslContext = SSLContext.getInstance("TLSv1.2")
197
         sslContext.init(keyManagerFactory.keyManagers, null, null)
190
         sslContext.init(keyManagerFactory.keyManagers, null, null)

+ 152
- 0
src/test/kotlin/com/dmdirc/ktirc/io/TlsTest.kt View File

1
+package com.dmdirc.ktirc.io
2
+
3
+import io.mockk.every
4
+import io.mockk.mockk
5
+import org.junit.jupiter.api.Assertions.assertFalse
6
+import org.junit.jupiter.api.Assertions.assertTrue
7
+import org.junit.jupiter.api.Test
8
+import java.security.cert.CertificateException
9
+import java.security.cert.X509Certificate
10
+
11
+internal class CertificateValidationTest {
12
+
13
+    private val cert = mockk<X509Certificate>()
14
+
15
+    @Test
16
+    fun `checks common name`() {
17
+        every { cert.subjectX500Principal } returns mockk {
18
+            every { name } returns "CN=subdomain.test.ktirc,O=testing,L=London,C=GB"
19
+        }
20
+
21
+        assertTrue(cert.validFor("subdomain.test.ktirc"))
22
+        assertFalse(cert.validFor("subdomain2.test.ktirc"))
23
+        assertFalse(cert.validFor("testing"))
24
+    }
25
+
26
+    @Test
27
+    fun `checks common name with suffixed wildcard`() {
28
+        every { cert.subjectX500Principal } returns mockk {
29
+            every { name } returns "CN=subdomain*.test.ktirc,O=testing,L=London,C=GB"
30
+        }
31
+
32
+        assertTrue(cert.validFor("subdomain.test.ktirc"))
33
+        assertTrue(cert.validFor("subdomain2.test.ktirc"))
34
+        assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
35
+        assertFalse(cert.validFor("1subdomain.test.ktirc"))
36
+    }
37
+
38
+    @Test
39
+    fun `checks common name with preixed wildcard`() {
40
+        every { cert.subjectX500Principal } returns mockk {
41
+            every { name } returns "CN=*subdomain.test.ktirc,O=testing,L=London,C=GB"
42
+        }
43
+
44
+        assertTrue(cert.validFor("subdomain.test.ktirc"))
45
+        assertTrue(cert.validFor("1subdomain.test.ktirc"))
46
+        assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
47
+        assertFalse(cert.validFor("subdomain1.test.ktirc"))
48
+    }
49
+
50
+    @Test
51
+    fun `checks common name with infixed wildcard`() {
52
+        every { cert.subjectX500Principal } returns mockk {
53
+            every { name } returns "CN=sub*domain.test.ktirc,O=testing,L=London,C=GB"
54
+        }
55
+
56
+        assertTrue(cert.validFor("subdomain.test.ktirc"))
57
+        assertTrue(cert.validFor("SUB-domain.test.ktirc"))
58
+        assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
59
+        assertFalse(cert.validFor("subdomain1.test.ktirc"))
60
+    }
61
+
62
+    @Test
63
+    fun `ignores wildcards in CN if they're not left-most`() {
64
+        every { cert.subjectX500Principal } returns mockk {
65
+            every { name } returns "CN=foo.*domain.test.ktirc,O=testing,L=London,C=GB"
66
+        }
67
+
68
+        assertFalse(cert.validFor("foo.domain.test.ktirc"))
69
+        assertFalse(cert.validFor("foo-test.domain.test.ktirc"))
70
+        assertFalse(cert.validFor("foo.test-domain.test.ktirc"))
71
+    }
72
+
73
+    @Test
74
+    fun `ignores wildcards in CN if there are too many`() {
75
+        every { cert.subjectX500Principal } returns mockk {
76
+            every { name } returns "CN=*domain*.test.ktirc,O=testing,L=London,C=GB"
77
+        }
78
+
79
+        assertFalse(cert.validFor("domain.test.ktirc"))
80
+        assertFalse(cert.validFor("subdomain.test.ktirc"))
81
+        assertFalse(cert.validFor("domain1.test.ktirc"))
82
+    }
83
+
84
+    @Test
85
+    fun `checks all sans`() {
86
+        every { cert.subjectAlternativeNames } returns listOf(
87
+                listOf(4, "directory.test.ktirc"),
88
+                listOf(2, "subdomain1.test.ktirc"),
89
+                listOf(2, "subdomain2.test.ktirc"),
90
+                listOf(2, "subdomain3.test.ktirc")
91
+        )
92
+
93
+        assertTrue(cert.validFor("subdomain1.test.ktirc"))
94
+        assertTrue(cert.validFor("subdomain2.test.KTIRC"))
95
+        assertTrue(cert.validFor("subdomain3.test.ktirc"))
96
+        assertFalse(cert.validFor("directory.test.ktirc"))
97
+    }
98
+
99
+    @Test
100
+    fun `checks wildcard sans`() {
101
+        every { cert.subjectAlternativeNames } returns listOf(
102
+                listOf(4, "directory.test.ktirc"),
103
+                listOf(2, "*domain1.test.ktirc"),
104
+                listOf(2, "subdomain*.test.ktirc"),
105
+                listOf(2, "*foo*.test.ktirc"),
106
+                listOf(2, "foo.*.ktirc")
107
+        )
108
+
109
+        assertTrue(cert.validFor("subdomain1.test.ktirc"))
110
+        assertTrue(cert.validFor("subdomain2.test.ktirc"))
111
+        assertTrue(cert.validFor("gooddomain1.TEST.ktirc"))
112
+        assertFalse(cert.validFor("foo.test.ktirc"))
113
+    }
114
+
115
+    @Test
116
+    fun `still uses CN if sans throws`() {
117
+        every { cert.subjectX500Principal } returns mockk {
118
+            every { name } returns "CN=subdomain.test.ktirc,O=testing,L=London,C=GB"
119
+        }
120
+        every { cert.subjectAlternativeNames } throws CertificateException("Oops")
121
+
122
+        assertTrue(cert.validFor("subdomain.test.ktirc"))
123
+        assertFalse(cert.validFor("subdomain2.test.ktirc"))
124
+        assertFalse(cert.validFor("testing"))
125
+    }
126
+
127
+    @Test
128
+    fun `still uses sans if CN throws`() {
129
+        every { cert.subjectX500Principal } throws CertificateException("Oops")
130
+        every { cert.subjectAlternativeNames } returns listOf(
131
+                listOf(4, "directory.test.ktirc"),
132
+                listOf(2, "subdomain1.test.ktirc"),
133
+                listOf(2, "subdomain2.test.ktirc"),
134
+                listOf(2, "subdomain3.test.ktirc")
135
+        )
136
+
137
+        assertTrue(cert.validFor("subdomain1.test.ktirc"))
138
+        assertTrue(cert.validFor("subdomain2.test.KTIRC"))
139
+        assertTrue(cert.validFor("subdomain3.test.ktirc"))
140
+        assertFalse(cert.validFor("directory.test.ktirc"))
141
+    }
142
+
143
+
144
+    @Test
145
+    fun `fails if CN and sans missing`() {
146
+        assertFalse(cert.validFor("subdomain1.test.ktirc"))
147
+        assertFalse(cert.validFor("subdomain2.test.KTIRC"))
148
+        assertFalse(cert.validFor("subdomain3.test.ktirc"))
149
+        assertFalse(cert.validFor("directory.test.ktirc"))
150
+    }
151
+
152
+}

Loading…
Cancel
Save