You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

Tls.kt 7.2KB


  1. package com.dmdirc.ktirc.io
  2. import com.dmdirc.ktirc.util.logger
  3. import kotlinx.coroutines.CoroutineScope
  4. import kotlinx.coroutines.channels.Channel
  5. import kotlinx.coroutines.channels.ClosedReceiveChannelException
  6. import kotlinx.coroutines.io.ByteChannel
  7. import kotlinx.coroutines.io.ByteWriteChannel
  8. import kotlinx.coroutines.launch
  9. import kotlinx.io.pool.useInstance
  10. import java.net.SocketAddress
  11. import java.nio.ByteBuffer
  12. import java.security.cert.CertificateException
  13. import java.security.cert.X509Certificate
  14. import java.util.regex.Pattern
  15. import javax.naming.ldap.LdapName
  16. import javax.naming.ldap.Rdn
  17. import javax.net.ssl.SSLContext
  18. import javax.net.ssl.SSLEngine
  19. import javax.net.ssl.SSLEngineResult
  /**
   * A [Socket] that layers TLS on top of another [Socket] using an [SSLEngine].
   *
   * Plaintext written by the application to [write] is encrypted and written to [socket]; encrypted data read from
   * [socket] is decrypted and handed out via [read]. Once the initial handshake in [connect] finishes, one coroutine
   * pumping reads ([readLoop]) and one pumping writes ([writeLoop]) are launched in [scope].
   *
   * @param scope scope in which the read and write pump coroutines are launched.
   * @param socket the underlying transport that carries the encrypted bytes.
   * @param sslContext context used to create a fresh [SSLEngine] on every [connect].
   * @param hostname the server's hostname; passed to the engine as the peer host and checked against the peer's
   * certificate when a handshake finishes.
   */
  internal class TlsSocket(
      private val scope: CoroutineScope,
      private val socket: Socket,
      private val sslContext: SSLContext,
      private val hostname: String
  ) : Socket {

      private val log by logger()

      // Replaced by a fresh client-mode engine on every connect().
      private var engine: SSLEngine = sslContext.createSSLEngine()

      // Encrypted bytes received from [socket] that the engine has not consumed yet.
      private var incomingNetBuffer = ByteBuffer.allocate(0)

      // Plaintext taken from [writeChannel] (i.e. coming in from the application) waiting to be wrapped and sent.
      private var incomingAppBuffer = ByteBuffer.allocate(0)

      // Decrypted buffers (borrowed from byteBufferPool) waiting to be returned from [read].
      private var outgoingAppBuffers = Channel<ByteBuffer>(capacity = Channel.UNLIMITED)

      private var writeChannel = ByteChannel(autoFlush = true)

      override val write: ByteWriteChannel
          get() = writeChannel

      override val isOpen: Boolean
          get() = socket.isOpen

      override fun bind(socketAddress: SocketAddress) {
          socket.bind(socketAddress)
      }

      /**
       * Connects the underlying socket and performs the TLS handshake, then starts the read/write pumps.
       *
       * @throws CertificateException if the peer's certificate is not valid for [hostname].
       */
      override suspend fun connect(socketAddress: SocketAddress) {
          writeChannel = ByteChannel(autoFlush = true)
          // Giving the engine the peer host/port (rather than none) lets JSSE send the hostname as SNI, so servers
          // that host several certificates can pick the right one. The port is only a hint; -1 when it is unknown.
          val peerPort = (socketAddress as? java.net.InetSocketAddress)?.port ?: -1
          engine = sslContext.createSSLEngine(hostname, peerPort).apply {
              useClientMode = true
          }
          incomingNetBuffer = ByteBuffer.allocate(engine.session.packetBufferSize + BUFFER_SIZE)
          outgoingAppBuffers = Channel(capacity = Channel.UNLIMITED)
          incomingAppBuffer = ByteBuffer.allocate(engine.session.applicationBufferSize)
          socket.connect(socketAddress)
          engine.beginHandshake()
          sslLoop(startPumps = true)
      }

      /**
       * Drives the engine's handshake state machine (running tasks, wrapping and unwrapping) until it no longer
       * needs any of those.
       *
       * @param initialResult result of a wrap/unwrap already performed by the caller; its handshake status is the
       * starting point. When null, the engine's current handshake status is used instead.
       * @param startPumps whether to launch [readLoop] and [writeLoop] when a handshake finishes. Only the initial
       * handshake in [connect] should do this: the engine can report FINISHED again for handshake activity observed
       * from inside the pumps (e.g. renegotiation), and launching there would start duplicate pumps racing on the
       * same engine and buffers.
       * @throws CertificateException if a handshake finishes with a peer certificate not valid for [hostname].
       */
      private suspend fun sslLoop(initialResult: SSLEngineResult? = null, startPumps: Boolean = false) {
          var handshakeStatus: SSLEngineResult.HandshakeStatus? =
              initialResult?.handshakeStatus ?: engine.handshakeStatus
          while (true) {
              handshakeStatus = when (handshakeStatus) {
                  SSLEngineResult.HandshakeStatus.NEED_TASK -> {
                      engine.delegatedTask.run()
                      engine.handshakeStatus
                  }
                  SSLEngineResult.HandshakeStatus.NEED_WRAP -> wrap()?.handshakeStatus
                  SSLEngineResult.HandshakeStatus.NEED_UNWRAP -> unwrap()?.handshakeStatus
                  SSLEngineResult.HandshakeStatus.FINISHED -> {
                      verifyPeerCertificate()
                      if (startPumps) {
                          scope.launch { readLoop() }
                          scope.launch { writeLoop() }
                      }
                      return
                  }
                  // NOT_HANDSHAKING, or null when wrap/unwrap gave no result (e.g. the socket was closed).
                  else -> return
              }
          }
      }

      /**
       * Checks the first certificate presented by the peer against [hostname].
       *
       * @throws CertificateException if the peer presented no certificates, or an X.509 certificate that does not
       * match [hostname].
       */
      private fun verifyPeerCertificate() {
          val certs = engine.session.peerCertificates
          // NOTE(review): a non-X.509 first certificate is not hostname-checked at all (the cast yields null).
          if (certs.isEmpty() || (certs[0] as? X509Certificate)?.validFor(hostname) == false) {
              throw CertificateException("Certificate is not valid for $hostname")
          }
      }

      /** Returns the next buffer of decrypted data, or null once the buffer channel has been closed by [close]. */
      override suspend fun read() = try {
          outgoingAppBuffers.receive()
      } catch (_: ClosedReceiveChannelException) {
          null
      }

      /**
       * Wraps (encrypts) pending application data and writes the resulting bytes to [socket].
       *
       * When no handshake is in progress this first pulls whatever the application has written to [writeChannel];
       * during a handshake it only wraps what the engine itself needs to send.
       *
       * @return the engine's result for this wrap.
       */
      private suspend fun wrap(): SSLEngineResult? {
          var result: SSLEngineResult? = null
          byteBufferPool.useInstance { netBuffer ->
              // HandshakeStatus is declared NOT_HANDSHAKING, FINISHED, NEED_*, so this means "not mid-handshake".
              if (engine.handshakeStatus <= SSLEngineResult.HandshakeStatus.FINISHED) {
                  writeChannel.readAvailable(incomingAppBuffer)
              }
              incomingAppBuffer.flip()
              result = engine.wrap(incomingAppBuffer, netBuffer)
              // Keep any bytes the engine didn't consume for the next wrap.
              incomingAppBuffer.compact()
              netBuffer.flip()
              socket.write.writeFully(netBuffer)
          }
          return result
      }

      /**
       * Unwraps (decrypts) data received from [socket], queueing any plaintext for [read].
       *
       * @param networkRead whether to read another buffer from [socket] first; by default only when no encrypted
       * bytes are left over from a previous call.
       * @return the engine's result, or null if [socket] returned no buffer (in which case this socket is closed).
       */
      private suspend fun unwrap(networkRead: Boolean = incomingNetBuffer.position() == 0): SSLEngineResult? {
          if (networkRead) {
              val buffer = socket.read()
              if (buffer == null) {
                  close()
                  return null
              }
              incomingNetBuffer.put(buffer)
              byteBufferPool.recycle(buffer)
          }
          incomingNetBuffer.flip()
          val buffer = byteBufferPool.borrow()
          val result = engine.unwrap(incomingNetBuffer, buffer)
          // Retain any partial TLS record for the next call.
          incomingNetBuffer.compact()
          if (buffer.position() > 0) {
              buffer.flip()
              outgoingAppBuffers.send(buffer)
          } else {
              byteBufferPool.recycle(buffer)
          }
          return if (result?.status == SSLEngineResult.Status.BUFFER_UNDERFLOW && !networkRead) {
              // We didn't do a network read, but SSLEngine is unhappy; force a read.
              log.finest { "Incoming net buffer underflowed, forcing re-read" }
              unwrap(true)
          } else {
              result
          }
      }

      /** Closes the underlying socket and releases any decrypted buffers that were never read. */
      override fun close() {
          socket.close()
          // Release any buffers we've got queued up
          while (true) {
              outgoingAppBuffers.poll()?.let {
                  byteBufferPool.recycle(it)
              } ?: break
          }
          outgoingAppBuffers.close()
      }

      /** Read pump: unwraps incoming data (handling any handshake it triggers) while the socket is open. */
      private suspend fun readLoop() {
          while (socket.isOpen) {
              sslLoop(unwrap())
          }
      }

      /** Write pump: wraps outgoing data (handling any handshake it triggers) while the socket is open. */
      private suspend fun writeLoop() {
          while (socket.isOpen) {
              sslLoop(wrap())
          }
      }
  }
  /**
   * Determines whether this certificate identifies [host].
   *
   * A certificate name (taken from `allNames`) matches when all three of these hold:
   * - it has the same number of dot-separated labels as [host];
   * - its left-most label matches the host's left-most label, where that label may contain at most one `*` wildcard;
   * - every remaining label equals the corresponding host label, ignoring case.
   */
  internal fun X509Certificate.validFor(host: String): Boolean {
      val hostLabels = host.split('.')
      return allNames.any { name ->
          val nameLabels = name.split('.')
          nameLabels.size == hostLabels.size &&
              nameLabels.first().wildCardMatches(hostLabels.first()) &&
              nameLabels.drop(1).zip(hostLabels.drop(1)).all { (expected, actual) ->
                  expected.equals(actual, ignoreCase = true)
              }
      }
  }
  /**
   * Matches a single certificate label (the receiver) against a single [host] label, case-insensitively.
   *
   * The label may contain at most one `*`, which matches any run of characters. All other characters match literally.
   * Labels with more than one `*` never match.
   */
  private fun String.wildCardMatches(host: String): Boolean {
      if (count { it == '*' } > 1) return false
      val pattern = split('*').joinToString(separator = ".*", transform = Pattern::quote)
      return Regex(pattern, RegexOption.IGNORE_CASE).matches(host)
  }
  /**
   * The DNS identities this certificate claims, for hostname matching.
   *
   * Per RFC 6125 §6.4.4 (and as the JDK's own hostname checker does), when the certificate carries any dNSName
   * subject alternative names, only those are used. The subject common name is consulted only as a fallback for
   * certificates without them. Always including the CN would accept a certificate for its CN's host even when its
   * SANs do not cover that host.
   */
  private val X509Certificate.allNames: Sequence<String>
      get() = sequence {
          val altNames = subjectAlternateNames
          if (altNames.isEmpty()) {
              commonName?.let { yield(it) }
          } else {
              yieldAll(altNames)
          }
      }
  /**
   * The dNSName entries (GeneralName tag 2) among the certificate's subject alternative names.
   *
   * This is an empty set when the certificate has none, or when reading them fails for any reason.
   */
  private val X509Certificate.subjectAlternateNames: Set<String>
      get() = nullOnThrow {
          subjectAlternativeNames?.let { entries ->
              entries.asSequence()
                  .filter { entry -> entry[0] == 2 }
                  .map { entry -> entry[1].toString() }
                  .toSet()
          }
      }.orEmpty()
  /** The first "CN" value in the certificate subject, or null if there is none or the subject can't be parsed. */
  private val X509Certificate.commonName: String?
      get() = nullOnThrow {
          val cnEntries = rdns.getOrElse("CN") { emptyList() }
          cnEntries.firstOrNull()?.value?.toString()
      }
  /**
   * The certificate subject's relative distinguished names, keyed by upper-cased attribute type (e.g. "CN").
   *
   * Throws if the subject name cannot be parsed as an LDAP name.
   */
  private val X509Certificate.rdns: Map<String, List<Rdn>>
      get() {
          val subjectName = LdapName(subjectX500Principal.name)
          return subjectName.rdns.groupBy { rdn -> rdn.type.toUpperCase() }
      }
  /**
   * Runs [block] and returns its result, or null if it throws.
   *
   * NOTE(review): this catches every Throwable, not just Exception. That is presumably deliberate, so that Errors
   * (e.g. a missing javax.naming class on some platforms) also degrade to "value unavailable". Confirm before
   * narrowing it.
   */
  private inline fun <S> nullOnThrow(block: () -> S?): S? = runCatching(block).getOrNull()