Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

LineBufferedSocket.kt 3.9KB

Lines 1–110
package com.dmdirc.ktirc.io

import com.dmdirc.ktirc.util.logger
import io.ktor.network.selector.ActorSelectorManager
import io.ktor.network.sockets.Socket
import io.ktor.network.sockets.aSocket
import io.ktor.network.sockets.openReadChannel
import io.ktor.network.sockets.openWriteChannel
import io.ktor.network.tls.tls
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.produce
import kotlinx.coroutines.io.ByteReadChannel
import kotlinx.coroutines.io.ByteWriteChannel
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import java.net.InetSocketAddress
import java.security.SecureRandom
import javax.net.ssl.X509TrustManager
  21. internal interface LineBufferedSocket {
  22. suspend fun connect()
  23. fun disconnect()
  24. suspend fun sendLine(line: ByteArray, offset: Int = 0, length: Int = line.size)
  25. suspend fun sendLine(line: String)
  26. fun readLines(coroutineScope: CoroutineScope): ReceiveChannel<ByteArray>
  27. }
  28. /**
  29. * Asynchronous socket that buffers incoming data and emits individual lines.
  30. */
  31. // TODO: Expose advanced TLS options
  32. internal class KtorLineBufferedSocket(private val host: String, private val port: Int, private val tls: Boolean = false): LineBufferedSocket {
  33. companion object {
  34. const val CARRIAGE_RETURN = '\r'.toByte()
  35. const val LINE_FEED = '\n'.toByte()
  36. }
  37. var tlsTrustManager: X509TrustManager? = null
  38. private val log by logger()
  39. private val writeLock = Mutex()
  40. private lateinit var socket: Socket
  41. private lateinit var readChannel: ByteReadChannel
  42. private lateinit var writeChannel: ByteWriteChannel
  43. @Suppress("EXPERIMENTAL_API_USAGE")
  44. override suspend fun connect() {
  45. log.info { "Connecting..." }
  46. socket = aSocket(ActorSelectorManager(Dispatchers.IO)).tcp().connect(InetSocketAddress(host, port))
  47. if (tls) {
  48. // TODO: Figure out how exactly scopes work...
  49. socket = socket.tls(GlobalScope.coroutineContext, randomAlgorithm = SecureRandom.getInstanceStrong().algorithm, trustManager = tlsTrustManager)
  50. }
  51. readChannel = socket.openReadChannel()
  52. writeChannel = socket.openWriteChannel()
  53. }
  54. override fun disconnect() {
  55. log.info { "Disconnecting..." }
  56. socket.close()
  57. }
  58. override suspend fun sendLine(line: ByteArray, offset: Int, length: Int) {
  59. writeLock.lock()
  60. try {
  61. with(writeChannel) {
  62. log.fine { ">>> ${String(line, offset, length)}" }
  63. writeAvailable(line, offset, length)
  64. writeByte(CARRIAGE_RETURN)
  65. writeByte(LINE_FEED)
  66. flush()
  67. }
  68. } finally {
  69. writeLock.unlock()
  70. }
  71. }
  72. override suspend fun sendLine(line: String) = sendLine(line.toByteArray())
  73. @ExperimentalCoroutinesApi
  74. override fun readLines(coroutineScope: CoroutineScope) = coroutineScope.produce {
  75. val lineBuffer = ByteArray(4096)
  76. var index = 0
  77. while (!readChannel.isClosedForRead) {
  78. var start = index
  79. val count = readChannel.readAvailable(lineBuffer, index, lineBuffer.size - index)
  80. for (i in index until index + count) {
  81. if (lineBuffer[i] == CARRIAGE_RETURN || lineBuffer[i] == LINE_FEED) {
  82. if (start < i) {
  83. val line = lineBuffer.sliceArray(start until i)
  84. log.fine { "<<< ${String(line)}" }
  85. send(line)
  86. }
  87. start = i + 1
  88. }
  89. }
  90. lineBuffer.copyInto(lineBuffer, 0, start)
  91. index = count + index - start
  92. }
  93. }
  94. }