package com.dmdirc.ktirc.io

import com.dmdirc.ktirc.util.logger
import io.ktor.network.selector.ActorSelectorManager
import io.ktor.network.sockets.Socket
import io.ktor.network.sockets.aSocket
import io.ktor.network.sockets.openReadChannel
import io.ktor.network.sockets.openWriteChannel
import io.ktor.network.tls.tls
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.produce
import kotlinx.coroutines.io.ByteReadChannel
import kotlinx.coroutines.io.ByteWriteChannel
import java.net.InetSocketAddress
import java.security.SecureRandom
import javax.net.ssl.X509TrustManager

/**
 * A socket that exchanges data one line at a time, with each line represented as a raw [ByteArray].
 */
internal interface LineBufferedSocket {

    /** Establishes the connection. */
    suspend fun connect()

    /** Closes the connection. */
    fun disconnect()

    /**
     * Starts reading from the socket in the given [coroutineScope].
     *
     * @return a channel that receives each incoming line, without its line terminator.
     */
    fun readLines(coroutineScope: CoroutineScope): ReceiveChannel<ByteArray>

    /**
     * Writes every line received from [channel] to the socket until the channel is closed.
     */
    suspend fun writeLines(channel: ReceiveChannel<ByteArray>)

}

/**
 * Asynchronous socket that buffers incoming data and emits individual lines.
 *
 * @param host the host to connect to.
 * @param port the port to connect to.
 * @param tls whether to wrap the connection in TLS after the TCP connection is established.
 */
// TODO: Expose advanced TLS options
internal class KtorLineBufferedSocket(
    private val host: String,
    private val port: Int,
    private val tls: Boolean = false
) : LineBufferedSocket {

    companion object {
        const val CARRIAGE_RETURN = '\r'.toByte()
        const val LINE_FEED = '\n'.toByte()
    }

    /**
     * Trust manager handed to the TLS layer by [connect] when TLS is enabled.
     * NOTE(review): a null value is passed straight through; presumably ktor then falls back to its
     * default trust manager - confirm against the ktor TLS documentation.
     */
    var tlsTrustManager: X509TrustManager? = null

    private val log by logger()

    private lateinit var socket: Socket
    private lateinit var readChannel: ByteReadChannel
    private lateinit var writeChannel: ByteWriteChannel

    /**
     * Opens a TCP connection to [host]:[port], upgrades it to TLS if requested, and opens the read and
     * write channels used by [readLines] and [writeLines].
     */
    @Suppress("EXPERIMENTAL_API_USAGE")
    override suspend fun connect() {
        log.info { "Connecting..." }
        socket = aSocket(ActorSelectorManager(Dispatchers.IO)).tcp().connect(InetSocketAddress(host, port))
        if (tls) {
            // TODO: Figure out how exactly scopes work...
            socket = socket.tls(
                GlobalScope.coroutineContext,
                randomAlgorithm = SecureRandom.getInstanceStrong().algorithm,
                trustManager = tlsTrustManager
            )
        }
        readChannel = socket.openReadChannel()
        writeChannel = socket.openWriteChannel()
    }

    /**
     * Closes the underlying socket. Does nothing beyond logging if [connect] never assigned a socket.
     */
    override fun disconnect() {
        log.info { "Disconnecting..." }
        // The socket is lateinit; guard so that disconnecting before connecting is a harmless no-op.
        if (::socket.isInitialized) {
            socket.close()
        }
    }

    /**
     * Reads from the socket until the read channel is closed, sending each complete line to the returned
     * channel.
     *
     * Lines are delimited by carriage return and/or line feed bytes; the delimiters are not included
     * and empty lines are skipped. Any unterminated data still buffered when the read channel closes
     * is discarded.
     */
    @ExperimentalCoroutinesApi
    override fun readLines(coroutineScope: CoroutineScope): ReceiveChannel<ByteArray> =
        coroutineScope.produce<ByteArray> {
            var lineBuffer = ByteArray(4096)
            // Number of bytes at the start of lineBuffer that belong to a not-yet-terminated line.
            var index = 0
            while (!readChannel.isClosedForRead) {
                if (index == lineBuffer.size) {
                    // A single line has filled the whole buffer. Grow it, otherwise every subsequent
                    // read would request zero bytes and the loop would never make progress.
                    lineBuffer = lineBuffer.copyOf(lineBuffer.size * 2)
                }

                val count = readChannel.readAvailable(lineBuffer, index, lineBuffer.size - index)
                if (count < 0) {
                    // Negative count signals end of stream; nothing more will arrive.
                    break
                }

                // Lines always start at the beginning of the buffer so that bytes carried over from a
                // previous read are included; only the newly read bytes need scanning for terminators.
                var start = 0
                val end = index + count
                for (i in index until end) {
                    if (lineBuffer[i] == CARRIAGE_RETURN || lineBuffer[i] == LINE_FEED) {
                        if (start < i) {
                            val line = lineBuffer.sliceArray(start until i)
                            log.fine { "<<< ${String(line)}" }
                            send(line)
                        }
                        start = i + 1
                    }
                }

                // Move any trailing partial line to the front of the buffer for the next read.
                lineBuffer.copyInto(lineBuffer, 0, start, end)
                index = end - start
            }
        }

    /**
     * Writes each line received from [channel] to the socket, followed by CR LF, flushing after
     * every line. Returns once [channel] is closed.
     */
    override suspend fun writeLines(channel: ReceiveChannel<ByteArray>) {
        for (line in channel) {
            with(writeChannel) {
                log.fine { ">>> ${String(line)}" }
                // writeFully suspends until the whole array is written; writeAvailable may write
                // only part of it, which would truncate the line.
                writeFully(line, 0, line.size)
                writeByte(CARRIAGE_RETURN)
                writeByte(LINE_FEED)
                flush()
            }
        }
    }

}