You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

LineBufferedSocket.kt 4.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. package com.dmdirc.ktirc.io
  2. import com.dmdirc.ktirc.util.logger
  3. import kotlinx.coroutines.*
  4. import kotlinx.coroutines.channels.Channel
  5. import kotlinx.coroutines.channels.ReceiveChannel
  6. import kotlinx.coroutines.channels.SendChannel
  7. import kotlinx.coroutines.channels.produce
  8. import kotlinx.coroutines.io.ByteWriteChannel
  9. import kotlinx.io.core.String
  10. import java.net.InetSocketAddress
  11. import java.nio.ByteBuffer
  12. import java.security.SecureRandom
  13. import java.security.cert.CertificateException
  14. import javax.net.ssl.SSLContext
  15. import javax.net.ssl.X509TrustManager
  /**
   * A connection that exchanges whole lines of bytes.
   *
   * Outgoing lines are queued on [sendChannel]; incoming lines are delivered on [receiveChannel].
   */
  internal interface LineBufferedSocket {

      /** Queue of outgoing lines. */
      val sendChannel: SendChannel<ByteArray>

      /** Stream of incoming lines. */
      val receiveChannel: ReceiveChannel<ByteArray>

      /**
       * Establishes the connection.
       *
       * Declared to throw [CertificateException] — presumably when TLS certificate validation fails; confirm in implementations.
       */
      @Throws(CertificateException::class)
      fun connect()

      /** Tears down the connection. */
      fun disconnect()
  }
  /**
   * Asynchronous socket that buffers incoming data and emits individual lines.
   *
   * Outgoing lines queued on [sendChannel] are written, each followed by CR LF, by a coroutine
   * started in [connect]. Incoming data is split on CR and/or LF; the terminators are not included
   * in the emitted byte arrays and empty segments (such as the gap inside a CR LF pair) are dropped.
   *
   * @param coroutineScope scope whose context, combined with [Dispatchers.IO], runs this socket's coroutines
   * @param host host name handed to the TLS layer (TlsSocket) — presumably for SNI / hostname checks; TODO confirm
   * @param ip address the underlying socket connects to
   * @param port port the underlying socket connects to
   * @param tls whether to wrap the connection in TLS (TLSv1.2)
   */
  // TODO: Expose advanced TLS options
  @ExperimentalCoroutinesApi
  internal class LineBufferedSocketImpl(coroutineScope: CoroutineScope, private val host: String, private val ip: String, private val port: Int, private val tls: Boolean = false) : CoroutineScope, LineBufferedSocket {

      companion object {
          const val CARRIAGE_RETURN = '\r'.toByte()
          const val LINE_FEED = '\n'.toByte()
      }

      override val coroutineContext = coroutineScope.newCoroutineContext(Dispatchers.IO)

      override val sendChannel: Channel<ByteArray> = Channel(Channel.UNLIMITED)

      /**
       * Optional trust manager for TLS connections. When null, `SSLContext.init` receives no trust
       * managers (per the JDK docs it then falls back to the platform defaults).
       */
      var tlsTrustManager: X509TrustManager? = null

      private val log by logger()

      private lateinit var socket: Socket
      private lateinit var writeChannel: ByteWriteChannel

      /**
       * Opens the connection, wrapping it in TLS when [tls] is set, then starts the writer coroutine.
       *
       * Blocks the calling thread (via `runBlocking`) until the underlying socket has connected.
       */
      override fun connect() {
          log.info { "Connecting..." }
          socket = PlainTextSocket(this)
          runBlocking {
              if (tls) {
                  with(SSLContext.getInstance("TLSv1.2")) {
                      // NOTE(review): SecureRandom.getInstanceStrong() may block on some platforms.
                      init(null, tlsTrustManager?.let { arrayOf(it) }, SecureRandom.getInstanceStrong())
                      // The TLS socket wraps the plain-text socket created above.
                      socket = TlsSocket(this@LineBufferedSocketImpl, socket, this, host)
                  }
              }
              socket.connect(InetSocketAddress(ip, port))
              writeChannel = socket.write
          }
          launch { writeLines() }
      }

      /** Closes the socket (if one was ever created) and cancels all of this socket's coroutines. */
      override fun disconnect() {
          log.info { "Disconnecting..." }
          // connect() may never have been called; accessing an unset lateinit would throw.
          if (::socket.isInitialized) {
              socket.close()
          }
          coroutineContext.cancel()
      }

      /**
       * Channel of incoming lines, without their CR/LF terminators.
       *
       * NOTE: this is a computed property — every access starts a new producer reading from the same
       * socket, so callers should read it once and keep the returned channel.
       */
      override val receiveChannel
          get() = produce {
              defaultPool.borrow { lineBuffer ->
                  // lineBuffer carries the trailing, not-yet-terminated part of a line between reads.
                  while (socket.isOpen) {
                      defaultPool.borrow { buffer ->
                          val bytesRead = socket.read(buffer)
                          var lineStart = 0
                          for (i in 0 until bytesRead) {
                              if (buffer[i] == CARRIAGE_RETURN || buffer[i] == LINE_FEED) {
                                  val length = i - lineStart + lineBuffer.position()
                                  // Only empty segments may be skipped: any non-empty line must be emitted
                                  // so lineBuffer gets cleared, otherwise its bytes leak into the next line.
                                  if (length > 0) {
                                      val output = ByteBuffer.allocate(length)
                                      lineBuffer.flip()
                                      output.put(lineBuffer)
                                      lineBuffer.clear()
                                      output.put(buffer.array(), lineStart, i - lineStart)
                                      log.fine { "<<< ${String(output.array())}" }
                                      send(output.array())
                                  }
                                  lineStart = i + 1
                              }
                          }
                          // Keep the unterminated tail for the next read.
                          // NOTE(review): a line longer than lineBuffer's capacity will overflow here.
                          lineBuffer.put(buffer.array(), lineStart, bytesRead - lineStart)
                      }
                  }
              }
          }

      /** Writes each queued line followed by CR LF, flushing after every line. */
      private suspend fun writeLines() {
          for (line in sendChannel) {
              with(writeChannel) {
                  log.fine { ">>> ${String(line)}" }
                  // writeAvailable may write only part of the array; writeFully suspends until all bytes are written.
                  writeFully(line, 0, line.size)
                  writeByte(CARRIAGE_RETURN)
                  writeByte(LINE_FEED)
                  flush()
              }
          }
      }
  }