You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

LineBufferedSocket.kt 4.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. package com.dmdirc.ktirc.io
  2. import com.dmdirc.ktirc.util.logger
  3. import io.ktor.network.selector.ActorSelectorManager
  4. import io.ktor.network.sockets.Socket
  5. import io.ktor.network.sockets.aSocket
  6. import io.ktor.network.sockets.openReadChannel
  7. import io.ktor.network.sockets.openWriteChannel
  8. import io.ktor.network.tls.tls
  9. import io.ktor.util.KtorExperimentalAPI
  10. import kotlinx.coroutines.*
  11. import kotlinx.coroutines.channels.Channel
  12. import kotlinx.coroutines.channels.ReceiveChannel
  13. import kotlinx.coroutines.channels.SendChannel
  14. import kotlinx.coroutines.channels.produce
  15. import kotlinx.coroutines.io.ByteReadChannel
  16. import kotlinx.coroutines.io.ByteWriteChannel
  17. import java.net.InetSocketAddress
  18. import java.security.SecureRandom
  19. import java.security.cert.CertificateException
  20. import javax.net.ssl.X509TrustManager
  /**
   * A socket whose traffic is exchanged as [ByteArray] messages over coroutine channels.
   *
   * The implementation visible in this file treats each array as one line: arrays taken from
   * [sendChannel] are written followed by CR LF, and incoming bytes are split on CR/LF before
   * being offered on [receiveChannel].
   */
  internal interface LineBufferedSocket {

      /** Channel onto which callers place the data they want written to the socket. */
      val sendChannel: SendChannel<ByteArray>

      /** Channel from which callers receive the data read from the socket. */
      val receiveChannel: ReceiveChannel<ByteArray>

      /**
       * Opens the connection.
       *
       * @throws CertificateException declared for callers that need to handle certificate problems.
       */
      @Throws(CertificateException::class)
      fun connect()

      /** Closes the connection. */
      fun disconnect()

  }
  /**
   * Asynchronous socket that buffers incoming data and emits individual lines.
   *
   * [connect] opens a TCP connection to [host]:[port] (optionally upgraded to TLS) and starts a
   * coroutine that drains [sendChannel], writing every array followed by CR LF. [receiveChannel]
   * splits the incoming byte stream on CR or LF and emits each non-empty line without its terminator.
   *
   * @param coroutineScope scope whose context, combined with [Dispatchers.IO], becomes this object's
   *        [coroutineContext].
   * @param host host to connect to.
   * @param port port to connect to.
   * @param tls whether to wrap the connection in TLS after connecting.
   */
  // TODO: Expose advanced TLS options
  @KtorExperimentalAPI
  @ExperimentalCoroutinesApi
  internal class KtorLineBufferedSocket(
      coroutineScope: CoroutineScope,
      private val host: String,
      private val port: Int,
      private val tls: Boolean = false
  ) : CoroutineScope, LineBufferedSocket {

      companion object {
          const val CARRIAGE_RETURN = '\r'.toByte()
          const val LINE_FEED = '\n'.toByte()
      }

      override val coroutineContext = coroutineScope.newCoroutineContext(Dispatchers.IO)

      /** Unbounded queue of outgoing lines; each array is written to the socket followed by CR LF. */
      override val sendChannel: Channel<ByteArray> = Channel(Channel.UNLIMITED)

      /** Trust manager handed to the TLS layer when [tls] is enabled; `null` unless set by the caller. */
      var tlsTrustManager: X509TrustManager? = null

      private val log by logger()

      private lateinit var socket: Socket
      private lateinit var readChannel: ByteReadChannel
      private lateinit var writeChannel: ByteWriteChannel

      /**
       * Connects to the remote host, blocking the calling thread until the connection (and the TLS
       * upgrade, if enabled) is established, then launches the coroutine that writes queued lines.
       *
       * If the TLS upgrade or channel setup fails, the already-opened socket is closed before the
       * failure is rethrown.
       */
      override fun connect() {
          runBlocking {
              log.info { "Connecting..." }
              socket = aSocket(ActorSelectorManager(Dispatchers.IO)).tcp().connect(InetSocketAddress(host, port))
              try {
                  if (tls) {
                      socket = socket.tls(
                              coroutineContext = this@KtorLineBufferedSocket.coroutineContext,
                              randomAlgorithm = SecureRandom.getInstanceStrong().algorithm,
                              trustManager = tlsTrustManager)
                  }
                  readChannel = socket.openReadChannel()
                  writeChannel = socket.openWriteChannel()
              } catch (ex: Throwable) {
                  // Don't leak the TCP connection when setup fails part-way through.
                  socket.close()
                  throw ex
              }
          }
          launch { writeLines() }
      }

      /**
       * Closes the socket (if one was ever opened) and cancels [coroutineContext].
       *
       * Safe to call before [connect].
       */
      override fun disconnect() {
          log.info { "Disconnecting..." }
          // The socket only exists once connect() has got as far as opening it.
          if (::socket.isInitialized) {
              socket.close()
          }
          // NOTE(review): this context is derived from the scope passed to the constructor, so this
          // presumably cancels that scope's Job as well - confirm that is intended.
          coroutineContext.cancel()
      }

      /**
       * Produces the lines read from the socket, one [ByteArray] per line, without terminators.
       * Empty lines (e.g. the gap between CR and LF) are skipped.
       *
       * Every read of this property starts a new producer coroutine, and [connect] must have
       * completed first because the producer reads from the socket's read channel.
       */
      override val receiveChannel: ReceiveChannel<ByteArray>
          get() = produce {
              var lineBuffer = ByteArray(4096)
              // Number of bytes at the front of lineBuffer belonging to a line whose terminator
              // has not arrived yet.
              var index = 0
              while (!readChannel.isClosedForRead) {
                  val count = readChannel.readAvailable(lineBuffer, index, lineBuffer.size - index)
                  if (count < 0) {
                      // A negative count signals end of stream; nothing more will arrive.
                      break
                  }

                  // Unfinished data is always compacted to offset 0, so the current line starts
                  // there - not at `index`, which would drop the part received by earlier reads.
                  var start = 0
                  for (i in index until index + count) {
                      if (lineBuffer[i] == CARRIAGE_RETURN || lineBuffer[i] == LINE_FEED) {
                          if (start < i) {
                              val line = lineBuffer.sliceArray(start until i)
                              log.fine { "<<< ${String(line)}" }
                              send(line)
                          }
                          start = i + 1
                      }
                  }

                  // Move whatever follows the last terminator to the front for the next read.
                  lineBuffer.copyInto(lineBuffer, 0, start, index + count)
                  index = count + index - start

                  if (index == lineBuffer.size) {
                      // A single line filled the whole buffer. Grow it so the next read has room;
                      // otherwise every read would be zero-length and this loop would spin.
                      lineBuffer = lineBuffer.copyOf(lineBuffer.size * 2)
                  }
              }
          }

      /** Writes every array received on [sendChannel] to the socket, followed by CR LF, flushing after each. */
      private suspend fun writeLines() {
          for (line in sendChannel) {
              with(writeChannel) {
                  log.fine { ">>> ${String(line)}" }
                  writeAvailable(line, 0, line.size)
                  writeByte(CARRIAGE_RETURN)
                  writeByte(LINE_FEED)
                  flush()
              }
          }
      }

  }