You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

Tls.kt 7.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. package com.dmdirc.ktirc.io
  2. import com.dmdirc.ktirc.util.logger
  3. import kotlinx.coroutines.CoroutineScope
  4. import kotlinx.coroutines.channels.Channel
  5. import kotlinx.coroutines.io.ByteChannel
  6. import kotlinx.coroutines.io.ByteWriteChannel
  7. import kotlinx.coroutines.launch
  8. import java.net.SocketAddress
  9. import java.nio.ByteBuffer
  10. import java.security.cert.CertificateException
  11. import java.security.cert.X509Certificate
  12. import java.util.regex.Pattern
  13. import javax.naming.ldap.LdapName
  14. import javax.naming.ldap.Rdn
  15. import javax.net.ssl.SSLContext
  16. import javax.net.ssl.SSLEngine
  17. import javax.net.ssl.SSLEngineResult
  /**
   * A [Socket] that carries TLS over another, plaintext [Socket].
   *
   * [connect] opens the underlying socket, runs the TLS handshake on a fresh client-mode
   * [SSLEngine] and checks that the peer's leaf certificate is valid for [hostname]. When the
   * handshake finishes, two coroutines are launched in [scope]: a read loop that decrypts records
   * arriving on [socket] (the plaintext is handed out by [read]), and a write loop that encrypts
   * whatever is written to [write] and sends it over [socket].
   *
   * @param scope scope in which the background read and write loops are launched
   * @param socket underlying socket that carries the encrypted records
   * @param sslContext context used to create a new [SSLEngine] for every connection
   * @param hostname host name the peer's certificate must be valid for; it is also given to the
   *   engine as its peer-host hint (which JSSE can use for SNI and session reuse)
   */
  internal class TlsSocket(
      private val scope: CoroutineScope,
      private val socket: Socket,
      private val sslContext: SSLContext,
      private val hostname: String
  ) : Socket {

      private val log by logger()

      // Placeholder; [connect] replaces it with a client-mode engine for each connection.
      private var engine: SSLEngine = sslContext.createSSLEngine()

      // Encrypted bytes read from [socket] that the engine has not consumed yet.
      private var incomingNetBuffer = ByteBuffer.allocate(0)

      // Plaintext taken from [writeChannel] that is waiting to be encrypted.
      private var incomingAppBuffer = ByteBuffer.allocate(0)

      // Decrypted, flipped buffers waiting to be returned by [read].
      private var outgoingAppBuffers = Channel<ByteBuffer>(capacity = Channel.UNLIMITED)

      // Plaintext written by users of this socket; exposed as [write].
      private var writeChannel = ByteChannel(autoFlush = true)

      // Whether the read/write loops have been launched for the current connection. The engine
      // reports FINISHED at the end of every handshake, not only the first, and the loops must
      // only ever be started once per connection.
      private var loopsStarted = false

      override val write: ByteWriteChannel
          get() = writeChannel

      override val isOpen: Boolean
          get() = socket.isOpen

      override fun bind(socketAddress: SocketAddress) {
          socket.bind(socketAddress)
      }

      /**
       * Connects the underlying socket and performs the TLS handshake.
       *
       * All per-connection state (engine, buffers and channels) is recreated on each call.
       *
       * @throws CertificateException if the peer's certificate is not valid for [hostname]
       */
      override suspend fun connect(socketAddress: SocketAddress) {
          writeChannel = ByteChannel(autoFlush = true)
          // Giving the engine the peer host/port lets JSSE send SNI and key its session cache;
          // -1 is what SSLEngine itself reports when the peer port is unknown.
          val peerPort = (socketAddress as? java.net.InetSocketAddress)?.port ?: -1
          engine = sslContext.createSSLEngine(hostname, peerPort).apply {
              useClientMode = true
          }
          incomingNetBuffer = ByteBuffer.allocate(engine.session.packetBufferSize)
          outgoingAppBuffers = Channel(capacity = Channel.UNLIMITED)
          incomingAppBuffer = ByteBuffer.allocate(engine.session.applicationBufferSize)
          loopsStarted = false
          socket.connect(socketAddress)
          engine.beginHandshake()
          sslLoop()
      }

      /**
       * Drives the engine through its handshake states until there is nothing left to do.
       *
       * @param initialResult result of a wrap/unwrap already performed by the caller, whose
       *   handshake status seeds the loop; when `null` the engine's current status is used
       */
      private suspend fun sslLoop(initialResult: SSLEngineResult? = null) {
          var handshakeStatus: SSLEngineResult.HandshakeStatus? =
              initialResult?.handshakeStatus ?: engine.handshakeStatus
          while (true) {
              handshakeStatus = when (handshakeStatus) {
                  SSLEngineResult.HandshakeStatus.NEED_TASK -> {
                      engine.delegatedTask.run()
                      engine.handshakeStatus
                  }
                  // wrap()/unwrap() return null when no engine operation took place; the null
                  // status then falls through to `else` and ends the loop.
                  SSLEngineResult.HandshakeStatus.NEED_WRAP -> wrap()?.handshakeStatus
                  SSLEngineResult.HandshakeStatus.NEED_UNWRAP -> unwrap()?.handshakeStatus
                  SSLEngineResult.HandshakeStatus.FINISHED -> {
                      onHandshakeFinished()
                      return
                  }
                  else -> return
              }
          }
      }

      /**
       * Verifies the peer certificate and, the first time a handshake finishes on this
       * connection, starts the background read and write loops.
       *
       * A later FINISHED (for example one observed by the read loop after another handshake)
       * re-verifies the certificate but must not start a second pair of loops, which would
       * compete for the same buffers.
       *
       * @throws CertificateException if the leaf certificate is missing, is not an
       *   [X509Certificate], or is not valid for [hostname]
       */
      private fun onHandshakeFinished() {
          val leaf = engine.session.peerCertificates.firstOrNull() as? X509Certificate
          // Fail closed: a certificate that cannot be checked is treated as invalid.
          if (leaf?.validFor(hostname) != true) {
              throw CertificateException("Certificate is not valid for $hostname")
          }
          if (!loopsStarted) {
              loopsStarted = true
              scope.launch { readLoop() }
              scope.launch { writeLoop() }
          }
      }

      /**
       * Copies the next decrypted buffer into [buffer], suspending until one is available.
       *
       * After [close] the internal channel is closed, so once any queued buffers are consumed
       * the receive fails (kotlinx `Channel.receive` throws on a closed channel).
       *
       * @return the number of bytes copied
       * @throws java.nio.BufferOverflowException if [buffer] has less room than the decrypted chunk
       */
      override suspend fun read(buffer: ByteBuffer): Int {
          val nextBuffer = outgoingAppBuffers.receive()
          val bytes = nextBuffer.limit()
          buffer.put(nextBuffer)
          defaultPool.recycle(nextBuffer)
          return bytes
      }

      /**
       * Encrypts pending data with the engine and writes the resulting records to [socket].
       *
       * Application data is only pulled from [writeChannel] while no handshake is in progress;
       * during a handshake the engine produces handshake data from an empty/leftover buffer.
       *
       * @return the engine's result, or `null` if [writeChannel] is closed and no buffered
       *   plaintext was left to encrypt
       */
      private suspend fun wrap(): SSLEngineResult? {
          var result: SSLEngineResult? = null
          defaultPool.borrow { netBuffer ->
              // HandshakeStatus is declared NOT_HANDSHAKING, FINISHED, NEED_TASK, ... so
              // `<= FINISHED` means "not currently handshaking".
              val outputFinished = engine.handshakeStatus <= SSLEngineResult.HandshakeStatus.FINISHED &&
                  writeChannel.readAvailable(incomingAppBuffer) == -1 &&
                  incomingAppBuffer.position() == 0
              if (!outputFinished) {
                  incomingAppBuffer.flip()
                  result = engine.wrap(incomingAppBuffer, netBuffer)
                  incomingAppBuffer.compact()
                  netBuffer.flip()
                  socket.write.writeFully(netBuffer)
              }
          }
          return result
      }

      /**
       * Decrypts the next chunk of [incomingNetBuffer], reading from [socket] first if requested.
       *
       * Any plaintext produced is queued on [outgoingAppBuffers] for [read].
       *
       * @param networkRead whether to read from [socket] before unwrapping; by default only when
       *   no undecrypted bytes are buffered
       * @return the engine's result, or `null` if reading from [socket] returned -1 (in which case
       *   this socket has been closed)
       */
      private suspend fun unwrap(networkRead: Boolean = incomingNetBuffer.position() == 0): SSLEngineResult? {
          if (networkRead) {
              // Read into the free space after any bytes that are already buffered.
              val bytes = socket.read(incomingNetBuffer.slice())
              if (bytes == -1) {
                  close()
                  return null
              }
              incomingNetBuffer.position(incomingNetBuffer.position() + bytes)
          }
          incomingNetBuffer.flip()
          val buffer = defaultPool.borrow()
          val result = engine.unwrap(incomingNetBuffer, buffer)
          // Keep any bytes the engine did not consume for the next call.
          incomingNetBuffer.compact()
          if (buffer.position() > 0) {
              buffer.flip()
              outgoingAppBuffers.send(buffer)
          } else {
              defaultPool.recycle(buffer)
          }
          return if (result?.status == SSLEngineResult.Status.BUFFER_UNDERFLOW && !networkRead) {
              // We didn't do a network read, but SSLEngine is unhappy; force a read.
              log.finest { "Incoming net buffer underflowed, forcing re-read" }
              unwrap(true)
          } else {
              result
          }
      }

      /**
       * Closes the underlying socket and the internal channels.
       *
       * Closing [writeChannel] wakes the write loop if it is suspended waiting for plaintext, so
       * it can observe the closed socket and exit instead of staying suspended forever.
       */
      override fun close() {
          socket.close()
          outgoingAppBuffers.close()
          writeChannel.close()
      }

      /** Reads and decrypts incoming data, handling any handshake work, while [socket] is open. */
      private suspend fun readLoop() {
          while (socket.isOpen) {
              sslLoop(unwrap())
          }
      }

      /**
       * Encrypts and sends outgoing data while [socket] is open and there is still output to
       * process; stops once [writeChannel] is closed and drained, rather than spinning on it.
       */
      private suspend fun writeLoop() {
          while (socket.isOpen && !isOutputFinished()) {
              sslLoop(wrap())
          }
      }

      /** Whether [writeChannel] is closed and empty and no plaintext remains in [incomingAppBuffer]. */
      private fun isOutputFinished() =
          writeChannel.isClosedForRead && incomingAppBuffer.position() == 0
  }
  /**
   * Checks whether any of this certificate's names covers [host].
   *
   * A name covers the host when it has the same number of dot-separated labels, its left-most
   * label matches the host's left-most label (that label may contain a single `*` wildcard), and
   * every remaining label is equal to the host's, ignoring case.
   */
  internal fun X509Certificate.validFor(host: String): Boolean {
      val hostLabels = host.split('.')
      return allNames.any { name ->
          val labels = name.split('.')
          labels.size == hostLabels.size &&
              labels[0].wildCardMatches(hostLabels[0]) &&
              (1 until hostLabels.size).all { i -> labels[i].equals(hostLabels[i], ignoreCase = true) }
      }
  }
  /**
   * Matches a single DNS label pattern against [host], ignoring case.
   *
   * The pattern may contain at most one `*`, which matches any run of characters; every other
   * character is matched literally. Patterns with more than one `*` never match.
   */
  private fun String.wildCardMatches(host: String): Boolean {
      if (count { it == '*' } > 1) {
          return false
      }
      val literalPieces = split('*')
      val pattern = literalPieces.joinToString(separator = ".*") { piece -> Regex.escape(piece) }
      return Regex(pattern, RegexOption.IGNORE_CASE).matches(host)
  }
  /**
   * The names [validFor] matches this certificate against.
   *
   * DNS subject alternative names take precedence. The subject's common name is used only when
   * the certificate carries no DNS subject alternative names, as RFC 6125 §6.4.4 requires
   * (a client must not fall back to the CN when DNS-IDs are present).
   */
  private val X509Certificate.allNames: Sequence<String>
      get() = sequence {
          val altNames = subjectAlternateNames
          if (altNames.isEmpty()) {
              commonName?.let { yield(it) }
          } else {
              yieldAll(altNames)
          }
      }
  /**
   * The dNSName entries (general-name type 2) of the subject alternative name extension.
   *
   * Empty if the extension is absent or cannot be parsed.
   */
  private val X509Certificate.subjectAlternateNames: Set<String>
      get() {
          val dnsNames = nullOnThrow {
              subjectAlternativeNames?.mapNotNullTo(LinkedHashSet<String>()) { entry ->
                  if (entry[0] == 2) entry[1].toString() else null
              }
          }
          return dnsNames ?: emptySet()
      }
  /** The value of the first `CN` entry in the subject's RDNs, or `null` if absent or unparseable. */
  private val X509Certificate.commonName: String?
      get() = nullOnThrow {
          val firstCn = rdns["CN"].orEmpty().firstOrNull()
          firstCn?.value?.toString()
      }
  /**
   * The subject's relative distinguished names, grouped by upper-cased attribute type.
   *
   * NOTE(review): `toUpperCase()` uses the default locale; harmless for the `CN` lookup, but
   * `Locale.ROOT` would be safer if other attribute types are ever looked up.
   */
  private val X509Certificate.rdns: Map<String, List<Rdn>>
      get() {
          val subject = LdapName(subjectX500Principal.name)
          return subject.rdns.groupBy { rdn -> rdn.type.toUpperCase() }
      }
  /**
   * Runs [block] and returns its result, or `null` if it throws an [Exception].
   *
   * JVM [Error]s (such as [OutOfMemoryError]) are deliberately not caught: they signal a broken
   * runtime rather than a malformed certificate, and swallowing them would hide the real failure.
   */
  private inline fun <S> nullOnThrow(block: () -> S?): S? = try {
      block()
  } catch (ex: Exception) {
      null
  }