You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

LineBufferedSocket.kt 3.6KB

  1. package
  2. import com.dmdirc.ktirc.util.logger
  3. import kotlinx.coroutines.*
  4. import kotlinx.coroutines.channels.Channel
  5. import kotlinx.coroutines.channels.ReceiveChannel
  6. import kotlinx.coroutines.channels.SendChannel
  7. import kotlinx.coroutines.channels.produce
  8. import
  9. import
  10. import java.nio.ByteBuffer
  11. import
  12. import
  13. import
  14. import
  /**
   * A socket whose traffic is exchanged as whole lines, each carried as a [ByteArray].
   */
  internal interface LineBufferedSocket {

      /** Channel onto which callers place the lines they want sent. */
      val sendChannel: SendChannel<ByteArray>

      /** Channel from which callers take the lines that were received. */
      val receiveChannel: ReceiveChannel<ByteArray>

      /**
       * Opens the connection.
       *
       * @throws CertificateException declared for callers that need to handle certificate failures.
       */
      @Throws(CertificateException::class)
      fun connect()

      /** Closes the connection. */
      fun disconnect()
  }
  22. /**
  23. * Asynchronous socket that buffers incoming data and emits individual lines.
      *
      * Outgoing lines are taken from [sendChannel] and written followed by CR LF. Incoming bytes are split on
      * CR or LF and each non-empty line is emitted, without its terminator, through [receiveChannel].
      *
      * @param coroutineScope scope whose context, combined with Dispatchers.IO, runs the coroutines of this socket
      * @param host host name used for the socket address and handed to the TLS wrapper
      * @param port port used for the socket address
      * @param tls when true the plain socket is wrapped in a TLS socket before connecting
  24. */
  25. // TODO: Expose advanced TLS options
  26. @ExperimentalCoroutinesApi
  27. internal class LineBufferedSocketImpl(coroutineScope: CoroutineScope, private val host: String, private val port: Int, private val tls: Boolean = false) : CoroutineScope, LineBufferedSocket {
  28. companion object {
      // Byte values of the two line terminator characters.
  29. const val CARRIAGE_RETURN = '\r'.toByte()
  30. const val LINE_FEED = '\n'.toByte()
  31. }
      // Context built from the supplied scope combined with Dispatchers.IO.
  32. override val coroutineContext = coroutineScope.newCoroutineContext(Dispatchers.IO)
      // Unlimited-capacity channel of outgoing lines; drained by writeLines().
  33. override val sendChannel: Channel<ByteArray> = Channel(Channel.UNLIMITED)
      // Optional trust manager handed to SSLContext.init when tls is set; when null, null is passed instead.
  34. var tlsTrustManager: X509TrustManager? = null
  35. private val log by logger()
      // Both are assigned in connect(); using them before connect() has run will fail.
  36. private lateinit var socket: Socket
  37. private lateinit var writeChannel: ByteWriteChannel
      // Creates the socket (wrapped for TLS when requested), blocks the calling thread until it is connected,
      // then launches the coroutine that writes queued lines.
  38. override fun connect() {
      // NOTE(review): bare lambda that is never invoked - presumably a logger call whose receiver was lost; confirm.
  39. { "Connecting..." }
  40. socket = PlainTextSocket(this)
  41. runBlocking {
  42. if (tls) {
  43. with (SSLContext.getInstance("TLSv1.2")) {
      // No key managers are supplied; trust managers come from tlsTrustManager when it is set.
  44. init(null, tlsTrustManager?.let { arrayOf(it) }, SecureRandom.getInstanceStrong())
      // Replaces the plain socket with a TLS socket wrapping it; the inner this is the SSLContext.
  45. socket = TlsSocket(this@LineBufferedSocketImpl, socket, this, host)
  46. }
  47. }
  48. socket.connect(InetSocketAddress(host, port))
      // NOTE(review): prints to standard output instead of using log - looks like leftover debug output; confirm.
  49. println("connected!")
  50. writeChannel = socket.write
  51. }
  52. launch { writeLines() }
  53. }
      // Closes the socket and then cancels the coroutine context of this instance.
  54. override fun disconnect() {
      // NOTE(review): bare lambda that is never invoked - presumably a logger call whose receiver was lost; confirm.
  55. { "Disconnecting..." }
  56. socket.close()
  57. coroutineContext.cancel()
  58. }
      // NOTE(review): this is a getter, so every read of the property calls produce again and starts a new
      // producer; callers presumably should read it only once - confirm.
  59. override val receiveChannel
  60. get() = produce {
      // Fixed-size buffer holding the bytes of the current, not yet terminated, line.
  61. val lineBuffer = ByteArray(16384)
      // Index in lineBuffer at which newly read bytes are placed.
  62. var nextByteOffset = 0
  63. while (socket.isOpen) {
  64. var lineStart = 0
      // NOTE(review): this expression is incomplete as shown (unbalanced parenthesis); presumably a read from
      // the socket into lineBuffer starting at nextByteOffset - restore from the upstream file and confirm.
  65. val bytesRead = { position(nextByteOffset) })
      // Only the newly read bytes are scanned; earlier bytes are already known to hold no terminator.
  66. for (i in nextByteOffset until nextByteOffset + bytesRead) {
  67. if (lineBuffer[i] == CARRIAGE_RETURN || lineBuffer[i] == LINE_FEED) {
      // Empty lines are skipped, so a CR immediately followed by LF yields a single line.
  68. if (lineStart < i) {
  69. val line = lineBuffer.sliceArray(lineStart until i)
  70. log.fine { "<<< ${String(line)}" }
  71. send(line)
  72. }
  73. lineStart = i + 1
  74. }
  75. }
      // Moves everything from lineStart onwards to the front of the buffer so the partial line continues at 0.
  76. lineBuffer.copyInto(lineBuffer, 0, lineStart)
  77. nextByteOffset += bytesRead - lineStart
      // NOTE(review): no handling is visible here for a negative bytesRead (end of stream) or for a line
      // longer than the buffer - confirm how those cases behave.
  78. }
  79. }
      // Writes every line taken from sendChannel followed by CR LF, flushing after each line.
  80. private suspend fun writeLines() {
  81. for (line in sendChannel) {
  82. with(writeChannel) {
  83. log.fine { ">>> ${String(line)}" }
  84. writeAvailable(line, 0, line.size)
  85. writeByte(CARRIAGE_RETURN)
  86. writeByte(LINE_FEED)
  87. flush()
  88. }
  89. }
  90. }
  91. }