You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

Tls.kt 7.3KB


  1. package com.dmdirc.ktirc.io
  2. import com.dmdirc.ktirc.util.logger
  3. import kotlinx.coroutines.CoroutineScope
  4. import kotlinx.coroutines.channels.Channel
  5. import kotlinx.coroutines.channels.ClosedReceiveChannelException
  6. import kotlinx.coroutines.io.ByteChannel
  7. import kotlinx.coroutines.io.ByteWriteChannel
  8. import kotlinx.coroutines.launch
  9. import java.net.SocketAddress
  10. import java.nio.ByteBuffer
  11. import java.security.cert.CertificateException
  12. import java.security.cert.X509Certificate
  13. import java.util.regex.Pattern
  14. import javax.naming.ldap.LdapName
  15. import javax.naming.ldap.Rdn
  16. import javax.net.ssl.SSLContext
  17. import javax.net.ssl.SSLEngine
  18. import javax.net.ssl.SSLEngineResult
  19. internal class TlsSocket(
  20. private val scope: CoroutineScope,
  21. private val socket: Socket,
  22. private val sslContext: SSLContext,
  23. private val hostname: String
  24. ) : Socket {
  25. private val log by logger()
  26. private var engine: SSLEngine = sslContext.createSSLEngine()
  27. private var incomingNetBuffer = ByteBuffer.allocate(0)
  28. private var incomingAppBuffer = ByteBuffer.allocate(0)
  29. private var outgoingAppBuffers = Channel<ByteBuffer>(capacity = Channel.UNLIMITED)
  30. private var writeChannel = ByteChannel(autoFlush = true)
  31. override val write: ByteWriteChannel
  32. get() = writeChannel
  33. override val isOpen: Boolean
  34. get() = socket.isOpen
  35. override fun bind(socketAddress: SocketAddress) {
  36. socket.bind(socketAddress)
  37. }
  38. override suspend fun connect(socketAddress: SocketAddress) {
  39. writeChannel = ByteChannel(autoFlush = true)
  40. engine = sslContext.createSSLEngine().apply {
  41. useClientMode = true
  42. }
  43. incomingNetBuffer = ByteBuffer.allocate(engine.session.packetBufferSize)
  44. outgoingAppBuffers = Channel(capacity = Channel.UNLIMITED)
  45. incomingAppBuffer = ByteBuffer.allocate(engine.session.applicationBufferSize)
  46. socket.connect(socketAddress)
  47. engine.beginHandshake()
  48. sslLoop()
  49. }
  50. private suspend fun sslLoop(initialResult: SSLEngineResult? = null) {
  51. var result: SSLEngineResult? = initialResult
  52. var handshakeStatus = result?.handshakeStatus ?: engine.handshakeStatus
  53. while (true) {
  54. when (handshakeStatus) {
  55. SSLEngineResult.HandshakeStatus.NEED_TASK -> {
  56. engine.delegatedTask.run()
  57. handshakeStatus = engine.handshakeStatus
  58. }
  59. SSLEngineResult.HandshakeStatus.NEED_WRAP -> {
  60. result = wrap()
  61. handshakeStatus = result?.handshakeStatus
  62. }
  63. SSLEngineResult.HandshakeStatus.NEED_UNWRAP -> {
  64. result = unwrap()
  65. handshakeStatus = result?.handshakeStatus
  66. }
  67. SSLEngineResult.HandshakeStatus.FINISHED -> {
  68. val certs = engine.session.peerCertificates
  69. if (certs.isEmpty() || (certs[0] as? X509Certificate)?.validFor(hostname) == false) {
  70. throw CertificateException("Certificate is not valid for $hostname")
  71. }
  72. scope.launch { readLoop() }
  73. scope.launch { writeLoop() }
  74. return
  75. }
  76. else -> return
  77. }
  78. }
  79. }
  80. override suspend fun read(buffer: ByteBuffer) = try {
  81. val nextBuffer = outgoingAppBuffers.receive()
  82. val bytes = nextBuffer.limit()
  83. buffer.put(nextBuffer)
  84. defaultPool.recycle(nextBuffer)
  85. bytes
  86. } catch (_: ClosedReceiveChannelException) {
  87. -1
  88. }
  89. private suspend fun wrap(): SSLEngineResult? {
  90. var result: SSLEngineResult? = null
  91. defaultPool.borrow { netBuffer ->
  92. if (engine.handshakeStatus <= SSLEngineResult.HandshakeStatus.FINISHED) {
  93. writeChannel.readAvailable(incomingAppBuffer)
  94. }
  95. incomingAppBuffer.flip()
  96. result = engine.wrap(incomingAppBuffer, netBuffer)
  97. incomingAppBuffer.compact()
  98. netBuffer.flip()
  99. socket.write.writeFully(netBuffer)
  100. }
  101. return result
  102. }
  103. private suspend fun unwrap(networkRead: Boolean = incomingNetBuffer.position() == 0): SSLEngineResult? {
  104. if (networkRead) {
  105. val bytes = socket.read(incomingNetBuffer.slice())
  106. if (bytes == -1) {
  107. close()
  108. return null
  109. }
  110. incomingNetBuffer.position(incomingNetBuffer.position() + bytes)
  111. }
  112. incomingNetBuffer.flip()
  113. val buffer = defaultPool.borrow()
  114. val result = engine.unwrap(incomingNetBuffer, buffer)
  115. incomingNetBuffer.compact()
  116. if (buffer.position() > 0) {
  117. buffer.flip()
  118. outgoingAppBuffers.send(buffer)
  119. } else {
  120. defaultPool.recycle(buffer)
  121. }
  122. return if (result?.status == SSLEngineResult.Status.BUFFER_UNDERFLOW && !networkRead) {
  123. // We didn't do a network read, but SSLEngine is unhappy; force a read.
  124. log.finest { "Incoming net buffer underflowed, forcing re-read" }
  125. unwrap(true)
  126. } else {
  127. result
  128. }
  129. }
  130. override fun close() {
  131. socket.close()
  132. // Release any buffers we've got queued up
  133. while(true) {
  134. outgoingAppBuffers.poll()?.let {
  135. defaultPool.recycle(it)
  136. } ?: break
  137. }
  138. outgoingAppBuffers.close()
  139. }
  140. private suspend fun readLoop() {
  141. while (socket.isOpen) {
  142. sslLoop(unwrap())
  143. }
  144. }
  145. private suspend fun writeLoop() {
  146. while (socket.isOpen) {
  147. sslLoop(wrap())
  148. }
  149. }
  150. }
  151. internal fun X509Certificate.validFor(host: String): Boolean {
  152. val hostParts = host.split('.')
  153. return allNames
  154. .map { it.split('.') }
  155. .filter { it.size == hostParts.size }
  156. .filter { it[0].wildCardMatches(hostParts[0]) }
  157. .any { it.zip(hostParts).slice(1 until hostParts.size).all { (part, host) -> part.equals(host, ignoreCase = true) } }
  158. }
  159. private fun String.wildCardMatches(host: String) =
  160. count { it == '*' } <= 1 &&
  161. host.matches(Regex(split('*').joinToString(".*") { Pattern.quote(it) }, RegexOption.IGNORE_CASE))
  162. private val X509Certificate.allNames: Sequence<String>
  163. get() = sequence {
  164. commonName?.let { yield(it) }
  165. yieldAll(subjectAlternateNames)
  166. }
  167. private val X509Certificate.subjectAlternateNames: Set<String>
  168. get() = nullOnThrow {
  169. subjectAlternativeNames
  170. ?.filter { it[0] == 2 }
  171. ?.map { it[1].toString() }
  172. ?.toSet()
  173. } ?: emptySet()
  174. private val X509Certificate.commonName: String?
  175. get() = nullOnThrow { rdns["CN"]?.firstOrNull()?.value?.toString() }
  176. private val X509Certificate.rdns: Map<String, List<Rdn>>
  177. get() = LdapName(subjectX500Principal.name).rdns.groupBy { it.type.toUpperCase() }
  178. private inline fun <S> nullOnThrow(block: () -> S?): S? = try {
  179. block()
  180. } catch (ex: Throwable) {
  181. null
  182. }