123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304 |
- package com.dmdirc.ktirc.io
-
import io.mockk.every
import io.mockk.mockk
import kotlinx.coroutines.*
import kotlinx.coroutines.io.writeFully
import kotlinx.io.core.String
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.parallel.Execution
import org.junit.jupiter.api.parallel.ExecutionMode
import java.io.DataInputStream
import java.net.InetSocketAddress
import java.net.ServerSocket
import java.security.KeyStore
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.X509TrustManager
import kotlin.coroutines.CoroutineContext
-
/**
 * Unit tests for the `X509Certificate.validFor` hostname check.
 *
 * Each test stubs the subject principal (whose `CN=` component supplies the common name) and/or the
 * subject alternative names of a MockK [X509Certificate], then asserts which hostnames are accepted.
 * SAN entries follow the `[type, value]` layout returned by `X509Certificate.getSubjectAlternativeNames`,
 * where type 2 is a dNSName and type 4 is a directoryName.
 */
internal class CertificateValidationTest {

    // Non-relaxed mock: any accessor a test does not stub throws if it is called.
    private val cert = mockk<X509Certificate>()

    @Test
    fun `checks common name`() {
        every { cert.subjectX500Principal } returns mockk {
            every { name } returns "CN=subdomain.test.ktirc,O=testing,L=London,C=GB"
        }

        assertTrue(cert.validFor("subdomain.test.ktirc"))
        assertFalse(cert.validFor("subdomain2.test.ktirc"))
        assertFalse(cert.validFor("testing"))
    }

    @Test
    fun `checks common name with suffixed wildcard`() {
        every { cert.subjectX500Principal } returns mockk {
            every { name } returns "CN=subdomain*.test.ktirc,O=testing,L=London,C=GB"
        }

        assertTrue(cert.validFor("subdomain.test.ktirc"))
        assertTrue(cert.validFor("subdomain2.test.ktirc"))
        assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
        assertFalse(cert.validFor("1subdomain.test.ktirc"))
    }

    @Test
    fun `checks common name with prefixed wildcard`() {
        every { cert.subjectX500Principal } returns mockk {
            every { name } returns "CN=*subdomain.test.ktirc,O=testing,L=London,C=GB"
        }

        assertTrue(cert.validFor("subdomain.test.ktirc"))
        assertTrue(cert.validFor("1subdomain.test.ktirc"))
        assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
        assertFalse(cert.validFor("subdomain1.test.ktirc"))
    }

    @Test
    fun `checks common name with infixed wildcard`() {
        every { cert.subjectX500Principal } returns mockk {
            every { name } returns "CN=sub*domain.test.ktirc,O=testing,L=London,C=GB"
        }

        assertTrue(cert.validFor("subdomain.test.ktirc"))
        assertTrue(cert.validFor("SUB-domain.test.ktirc"))
        assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
        assertFalse(cert.validFor("subdomain1.test.ktirc"))
    }

    @Test
    fun `ignores wildcards in CN if they're not left-most`() {
        every { cert.subjectX500Principal } returns mockk {
            every { name } returns "CN=foo.*domain.test.ktirc,O=testing,L=London,C=GB"
        }

        assertFalse(cert.validFor("foo.domain.test.ktirc"))
        assertFalse(cert.validFor("foo-test.domain.test.ktirc"))
        assertFalse(cert.validFor("foo.test-domain.test.ktirc"))
    }

    @Test
    fun `ignores wildcards in CN if there are too many`() {
        every { cert.subjectX500Principal } returns mockk {
            every { name } returns "CN=*domain*.test.ktirc,O=testing,L=London,C=GB"
        }

        assertFalse(cert.validFor("domain.test.ktirc"))
        assertFalse(cert.validFor("subdomain.test.ktirc"))
        assertFalse(cert.validFor("domain1.test.ktirc"))
    }

    @Test
    fun `checks all sans`() {
        every { cert.subjectAlternativeNames } returns listOf(
            listOf(4, "directory.test.ktirc"),
            listOf(2, "subdomain1.test.ktirc"),
            listOf(2, "subdomain2.test.ktirc"),
            listOf(2, "subdomain3.test.ktirc")
        )

        assertTrue(cert.validFor("subdomain1.test.ktirc"))
        assertTrue(cert.validFor("subdomain2.test.KTIRC"))
        assertTrue(cert.validFor("subdomain3.test.ktirc"))
        // Only dNSName (type 2) entries should count; the directoryName must be ignored.
        assertFalse(cert.validFor("directory.test.ktirc"))
    }

    @Test
    fun `checks wildcard sans`() {
        every { cert.subjectAlternativeNames } returns listOf(
            listOf(4, "directory.test.ktirc"),
            listOf(2, "*domain1.test.ktirc"),
            listOf(2, "subdomain*.test.ktirc"),
            listOf(2, "*foo*.test.ktirc"),
            listOf(2, "foo.*.ktirc")
        )

        assertTrue(cert.validFor("subdomain1.test.ktirc"))
        assertTrue(cert.validFor("subdomain2.test.ktirc"))
        assertTrue(cert.validFor("gooddomain1.TEST.ktirc"))
        assertFalse(cert.validFor("foo.test.ktirc"))
    }

    @Test
    fun `still uses CN if sans throws`() {
        every { cert.subjectX500Principal } returns mockk {
            every { name } returns "CN=subdomain.test.ktirc,O=testing,L=London,C=GB"
        }
        every { cert.subjectAlternativeNames } throws CertificateException("Oops")

        assertTrue(cert.validFor("subdomain.test.ktirc"))
        assertFalse(cert.validFor("subdomain2.test.ktirc"))
        assertFalse(cert.validFor("testing"))
    }

    @Test
    fun `still uses sans if CN throws`() {
        every { cert.subjectX500Principal } throws CertificateException("Oops")
        every { cert.subjectAlternativeNames } returns listOf(
            listOf(4, "directory.test.ktirc"),
            listOf(2, "subdomain1.test.ktirc"),
            listOf(2, "subdomain2.test.ktirc"),
            listOf(2, "subdomain3.test.ktirc")
        )

        assertTrue(cert.validFor("subdomain1.test.ktirc"))
        assertTrue(cert.validFor("subdomain2.test.KTIRC"))
        assertTrue(cert.validFor("subdomain3.test.ktirc"))
        assertFalse(cert.validFor("directory.test.ktirc"))
    }

    @Test
    fun `fails if CN and sans missing`() {
        // Neither accessor is stubbed on the strict mock, so no hostname may be accepted.
        assertFalse(cert.validFor("subdomain1.test.ktirc"))
        assertFalse(cert.validFor("subdomain2.test.KTIRC"))
        assertFalse(cert.validFor("subdomain3.test.ktirc"))
        assertFalse(cert.validFor("directory.test.ktirc"))
    }

}
-
/**
 * Integration tests for [TlsSocket] talking to a real TLS server socket on the loopback interface.
 *
 * Every test gets a fresh fixed-size thread pool and a [SupervisorJob] (see [setup]); [teardown]
 * cancels that job and closes the pool's dispatcher. Servers bind an ephemeral port (0) and clients
 * connect to the port actually assigned, so tests never depend on a particular port being free.
 */
@Suppress("BlockingMethodInNonBlockingContext")
@Execution(ExecutionMode.SAME_THREAD)
internal class TlsSocketTest : CoroutineScope {

    private lateinit var dispatcher: ExecutorCoroutineDispatcher
    private lateinit var job: Job

    // A bare dispatcher contains no Job, so CoroutineContext.cancel() on it would be a no-op.
    // Adding a supervisor job gives teardown something to cancel, while keeping the failure of one
    // launched coroutine from cancelling its siblings (as with the previous job-less root scope).
    override val coroutineContext: CoroutineContext
        get() = dispatcher + job

    @ObsoleteCoroutinesApi
    @BeforeEach
    fun setup() {
        dispatcher = newFixedThreadPoolContext(4, "tls-test")
        job = SupervisorJob()
    }

    @AfterEach
    fun teardown() {
        job.cancel()
        // Release the pool's threads; without this every test leaked a 4-thread executor.
        dispatcher.close()
    }

    @Test
    fun `can send a string to a server over TLS`() = runBlocking(coroutineContext) {
        withTimeout(5000) {
            tlsServerSocket(0).use { serverSocket ->
                val plainSocket = PlainTextSocket(this@TlsSocketTest)
                val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "localhost")
                val clientBytesAsync = this@TlsSocketTest.async {
                    ByteArray(13).apply {
                        // A single read() may return fewer bytes than requested; readFully waits for all 13.
                        DataInputStream(serverSocket.accept().getInputStream()).readFully(this)
                    }
                }

                tlsSocket.connect(InetSocketAddress("localhost", serverSocket.localPort))
                tlsSocket.write.writeFully("Hello World\r\n".toByteArray())

                val bytes = clientBytesAsync.await()
                assertNotNull(bytes)
                assertEquals("Hello World\r\n", String(bytes))
            }
        }
    }

    @Test
    fun `can read a string from a server over TLS`() = runBlocking<Unit>(coroutineContext) {
        withTimeout(5000) {
            tlsServerSocket(0).use { serverSocket ->
                val plainSocket = PlainTextSocket(this@TlsSocketTest)
                val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "localhost")
                val serverSide = this@TlsSocketTest.async {
                    serverSocket.accept().apply {
                        // SSLSocket handshakes implicitly on first I/O, so a pending read lets the
                        // server side take part in the handshake.
                        this@TlsSocketTest.launch {
                            getInputStream().read()
                        }
                    }
                }

                tlsSocket.connect(InetSocketAddress("localhost", serverSocket.localPort))

                this@TlsSocketTest.launch {
                    with(serverSide.await().getOutputStream()) {
                        write("Hack the planet!".toByteArray())
                        flush()
                    }
                }

                val buffer = tlsSocket.read()

                assertNotNull(buffer)
                buffer?.let {
                    assertEquals("Hack the planet!", String(it.array(), 0, it.limit()))
                }
            }
        }
    }

    @Test
    fun `read returns null after close`() = runBlocking(coroutineContext) {
        withTimeout(5000) {
            tlsServerSocket(0).use { serverSocket ->
                val plainSocket = PlainTextSocket(this@TlsSocketTest)
                val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "localhost")
                this@TlsSocketTest.launch {
                    serverSocket.accept().getInputStream().read()
                }

                tlsSocket.connect(InetSocketAddress("localhost", serverSocket.localPort))

                tlsSocket.close()

                val buffer = tlsSocket.read()

                assertNull(buffer)
            }
        }
    }

    @Test
    fun `throws if the hostname mismatches`() {
        tlsServerSocket(0).use { serverSocket ->
            val plainSocket = PlainTextSocket(this@TlsSocketTest)
            // The other tests show the server certificate (localhost.p12) is accepted for "localhost";
            // expecting "127.0.0.1" instead must fail hostname validation.
            val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "127.0.0.1")
            launch {
                serverSocket.accept().getInputStream().read()
            }

            assertThrows(CertificateException::class.java) {
                runBlocking(coroutineContext) {
                    withTimeout(5000) {
                        tlsSocket.connect(InetSocketAddress("localhost", serverSocket.localPort))
                    }
                }
            }
        }
    }
}
-
/**
 * Creates a TLS [ServerSocket] that presents the key material from the `localhost.p12` test keystore.
 *
 * The PKCS#12 keystore is loaded from the classpath relative to [CertificateValidationTest] with an
 * empty password. The server context uses TLSv1.2 and no custom trust managers.
 *
 * @param port the port to bind; `0` (the default) lets the OS choose a free ephemeral port, which
 *   callers can read back via [ServerSocket.getLocalPort].
 * @return a bound, unaccepted server socket; the caller is responsible for closing it.
 * @throws IllegalStateException if the `localhost.p12` resource cannot be found.
 */
internal fun tlsServerSocket(port: Int = 0): ServerSocket {
    // KeyStore.load(null, ...) silently creates an *empty* store, which would only surface later as
    // an obscure handshake failure, so fail fast when the resource is missing.
    val keyStoreStream = checkNotNull(CertificateValidationTest::class.java.getResourceAsStream("localhost.p12")) {
        "Test keystore 'localhost.p12' was not found on the classpath"
    }
    val keyStore = KeyStore.getInstance("PKCS12")
    keyStoreStream.use { keyStore.load(it, CharArray(0)) }

    val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
    keyManagerFactory.init(keyStore, CharArray(0))

    val sslContext = SSLContext.getInstance("TLSv1.2")
    sslContext.init(keyManagerFactory.keyManagers, null, null)
    return sslContext.serverSocketFactory.createServerSocket(port)
}
-
/**
 * Builds a TLSv1.2 [SSLContext] whose only trust manager is the accept-everything manager from
 * [getTrustingManager]. It uses the default key managers and default secure random source.
 *
 * Test-only: certificate chains are never validated by a context built this way.
 */
internal fun getTrustingContext(): SSLContext {
    val context = SSLContext.getInstance("TLSv1.2")
    val trustManagers = arrayOf(getTrustingManager())
    context.init(null, trustManagers, null)
    return context
}
-
/**
 * Returns an [X509TrustManager] that accepts every client and server certificate chain without
 * inspecting it, and advertises no accepted issuers.
 *
 * Test-only: it lets the tests connect to the server from `tlsServerSocket` without a configured
 * trust store, so chain validation does not interfere with the behaviour under test.
 * Never use this outside tests.
 */
internal fun getTrustingManager(): X509TrustManager = object : X509TrustManager {
    override fun checkClientTrusted(certs: Array<X509Certificate>, authType: String) = Unit

    override fun checkServerTrusted(certs: Array<X509Certificate>, authType: String) = Unit

    override fun getAcceptedIssuers(): Array<X509Certificate> = emptyArray()
}
|