package com.dmdirc.ktirc.io

import io.ktor.network.tls.certificates.generateCertificate
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.async
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.parallel.Execution
import org.junit.jupiter.api.parallel.ExecutionMode
import java.io.File
import java.net.ServerSocket
import java.security.KeyStore
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.X509TrustManager

/**
 * Integration tests for [KtorLineBufferedSocket] that talk to a real server socket on localhost.
 *
 * Every test binds the same fixed port (12321), so the class is pinned to
 * [ExecutionMode.SAME_THREAD] to stop tests from racing each other for the port.
 *
 * The server side of each test runs in [GlobalScope]: `accept()`, `read()` and `write()` on a
 * [ServerSocket] are blocking calls, and running them on the `runBlocking` thread would stall
 * the client code under test.
 */
@Execution(ExecutionMode.SAME_THREAD)
internal class KtorLineBufferedSocketTest {

    @Test
    fun `KtorLineBufferedSocket can connect to a server`() = runBlocking {
        ServerSocket(12321).use { serverSocket ->
            val socket = KtorLineBufferedSocket("localhost", 12321)
            val clientSocketAsync = GlobalScope.async { serverSocket.accept() }

            socket.connect()

            assertNotNull(clientSocketAsync.await())
        }
    }

    @Test
    fun `KtorLineBufferedSocket throws trying to connect to a server with a bad TLS cert`() = runBlocking {
        tlsServerSocket(12321).use { serverSocket ->
            try {
                // TLS must be enabled on the client (third argument, as in the TLS send test),
                // otherwise the self-signed certificate is never inspected. No trust manager is
                // installed, so the certificate is expected to be rejected.
                val socket = KtorLineBufferedSocket("localhost", 12321, true)

                // A server-side SSLSocket only starts its handshake on first I/O, so read from
                // the accepted socket to let the client see the certificate. The read is
                // expected to fail once the client aborts; that failure is irrelevant here.
                GlobalScope.launch {
                    serverSocket.accept().use { peer ->
                        runCatching { peer.getInputStream().read() }
                    }
                }

                socket.connect()
                fail<Unit>("Expected connect() to reject the untrusted certificate")
            } catch (ex: CertificateException) {
                // Expected
            }
        }
    }

    @Test
    fun `KtorLineBufferedSocket can send a whole byte array to a server`() = runBlocking {
        ServerSocket(12321).use { serverSocket ->
            val socket = KtorLineBufferedSocket("localhost", 12321)
            // 13 bytes = "Hello World" (11) + CRLF (2)
            val clientBytesAsync = GlobalScope.async {
                ByteArray(13).apply {
                    serverSocket.accept().getInputStream().read(this)
                }
            }

            socket.connect()
            socket.sendLine("Hello World".toByteArray())

            val bytes = clientBytesAsync.await()
            assertNotNull(bytes)
            assertEquals("Hello World\r\n", String(bytes))
        }
    }

    @Test
    fun `KtorLineBufferedSocket can send a string to a server`() = runBlocking {
        ServerSocket(12321).use { serverSocket ->
            val socket = KtorLineBufferedSocket("localhost", 12321)
            val clientBytesAsync = GlobalScope.async {
                ByteArray(13).apply {
                    serverSocket.accept().getInputStream().read(this)
                }
            }

            socket.connect()
            socket.sendLine("Hello World")

            val bytes = clientBytesAsync.await()
            assertNotNull(bytes)
            assertEquals("Hello World\r\n", String(bytes))
        }
    }

    @Test
    fun `KtorLineBufferedSocket can send a string to a server over TLS`() = runBlocking {
        tlsServerSocket(12321).use { serverSocket ->
            val socket = KtorLineBufferedSocket("localhost", 12321, true)
            // The server certificate is self-signed, so trust everything for this test.
            socket.tlsTrustManager = getTrustingManager()
            val clientBytesAsync = GlobalScope.async {
                ByteArray(13).apply {
                    serverSocket.accept().getInputStream().read(this)
                }
            }

            socket.connect()
            socket.sendLine("Hello World")

            val bytes = clientBytesAsync.await()
            assertNotNull(bytes)
            assertEquals("Hello World\r\n", String(bytes))
        }
    }

    @Test
    fun `KtorLineBufferedSocket can send a partial byte array to a server`() = runBlocking {
        ServerSocket(12321).use { serverSocket ->
            val socket = KtorLineBufferedSocket("localhost", 12321)
            // 7 bytes = "World" (5) + CRLF (2)
            val clientBytesAsync = GlobalScope.async {
                ByteArray(7).apply {
                    serverSocket.accept().getInputStream().read(this)
                }
            }

            socket.connect()
            // Offset 6, length 5 selects "World" out of "Hello World".
            socket.sendLine("Hello World".toByteArray(), 6, 5)

            val bytes = clientBytesAsync.await()
            assertNotNull(bytes)
            assertEquals("World\r\n", String(bytes))
        }
    }

    @Test
    fun `KtorLineBufferedSocket can receive a line of CRLF delimited text`() = runBlocking {
        ServerSocket(12321).use { serverSocket ->
            val socket = KtorLineBufferedSocket("localhost", 12321)
            GlobalScope.launch {
                serverSocket.accept().getOutputStream().write("Hi there\r\n".toByteArray())
            }

            socket.connect()

            assertEquals("Hi there", String(socket.readLines(GlobalScope).receive()))
        }
    }

    @Test
    fun `KtorLineBufferedSocket can receive a line of LF delimited text`() = runBlocking {
        ServerSocket(12321).use { serverSocket ->
            val socket = KtorLineBufferedSocket("localhost", 12321)
            GlobalScope.launch {
                serverSocket.accept().getOutputStream().write("Hi there\n".toByteArray())
            }

            socket.connect()

            assertEquals("Hi there", String(socket.readLines(GlobalScope).receive()))
        }
    }

    @Test
    fun `KtorLineBufferedSocket can receive multiple lines of text in one packet`() = runBlocking {
        ServerSocket(12321).use { serverSocket ->
            val socket = KtorLineBufferedSocket("localhost", 12321)
            GlobalScope.launch {
                // First line ends with LF, second with a lone CR.
                serverSocket.accept().getOutputStream().write("Hi there\nThis is a test\r".toByteArray())
            }

            socket.connect()

            val lineProducer = socket.readLines(GlobalScope)
            assertEquals("Hi there", String(lineProducer.receive()))
            assertEquals("This is a test", String(lineProducer.receive()))
        }
    }

    @Test
    fun `KtorLineBufferedSocket can receive one line of text over multiple packets`() = runBlocking {
        ServerSocket(12321).use { serverSocket ->
            val socket = KtorLineBufferedSocket("localhost", 12321)
            GlobalScope.launch {
                with(serverSocket.accept().getOutputStream()) {
                    // Flush after each fragment so the line is sent in separate writes.
                    write("Hi".toByteArray())
                    flush()
                    write(" t".toByteArray())
                    flush()
                    write("here\r\n".toByteArray())
                    flush()
                }
            }

            socket.connect()

            val lineProducer = socket.readLines(GlobalScope)
            assertEquals("Hi there", String(lineProducer.receive()))
        }
    }

    @Test
    fun `KtorLineBufferedSocket returns from readLines when socket is closed`() = runBlocking {
        ServerSocket(12321).use { serverSocket ->
            val socket = KtorLineBufferedSocket("localhost", 12321)
            GlobalScope.launch {
                with(serverSocket.accept()) {
                    getOutputStream().write("Hi there\r\n".toByteArray())
                    close()
                }
            }

            socket.connect()

            val lineProducer = socket.readLines(GlobalScope)
            assertEquals("Hi there", String(lineProducer.receive()))
        }
    }

    @Test
    fun `KtorLineBufferedSocket disconnects from server`() = runBlocking {
        ServerSocket(12321).use { serverSocket ->
            val socket = KtorLineBufferedSocket("localhost", 12321)
            val clientSocketAsync = GlobalScope.async { serverSocket.accept() }

            socket.connect()
            socket.disconnect()

            // read() returning -1 means the peer closed its end of the connection.
            assertEquals(-1, clientSocketAsync.await().getInputStream().read()) {
                "Server socket should EOF after KtorLineBufferedSocket disconnects"
            }
        }
    }

    /**
     * Creates a TLS server socket on [port] backed by a freshly generated certificate.
     *
     * The certificate is written to a temporary JKS key store by ktor's [generateCertificate]
     * and loaded back with the password "changeit" (presumably that function's default
     * password; confirm against the ktor version in use). The temporary file is removed when
     * the JVM exits.
     *
     * @param port the local port to listen on
     * @return a listening server socket that speaks TLS
     */
    private fun tlsServerSocket(port: Int): ServerSocket {
        val password = "changeit".toCharArray()
        val keyFile = File.createTempFile("selfsigned", "jks").apply { deleteOnExit() }
        generateCertificate(keyFile)

        val keyStore = KeyStore.getInstance("JKS")
        // Close the stream once the store is loaded rather than leaking the file handle.
        keyFile.inputStream().use { keyStore.load(it, password) }

        val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
        keyManagerFactory.init(keyStore, password)

        val sslContext = SSLContext.getInstance("TLSv1.2")
        sslContext.init(keyManagerFactory.keyManagers, null, null)
        return sslContext.serverSocketFactory.createServerSocket(port)
    }

    /**
     * Returns a trust manager that accepts every certificate chain without checking it.
     *
     * For tests only: it lets the client connect to the self-signed server created by
     * [tlsServerSocket].
     */
    private fun getTrustingManager() = object : X509TrustManager {
        override fun getAcceptedIssuers(): Array<X509Certificate> = emptyArray()
        override fun checkClientTrusted(certs: Array<X509Certificate>, authType: String) {}
        override fun checkServerTrusted(certs: Array<X509Certificate>, authType: String) {}
    }
}