123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- package com.dmdirc.ktirc.io
-
- import com.dmdirc.ktirc.util.logger
- import kotlinx.coroutines.CoroutineScope
- import kotlinx.coroutines.channels.Channel
- import kotlinx.coroutines.channels.ClosedReceiveChannelException
- import kotlinx.coroutines.io.ByteChannel
- import kotlinx.coroutines.io.ByteWriteChannel
- import kotlinx.coroutines.launch
- import java.net.SocketAddress
- import java.nio.ByteBuffer
- import java.security.cert.CertificateException
- import java.security.cert.X509Certificate
- import java.util.regex.Pattern
- import javax.naming.ldap.LdapName
- import javax.naming.ldap.Rdn
- import javax.net.ssl.SSLContext
- import javax.net.ssl.SSLEngine
- import javax.net.ssl.SSLEngineResult
-
-
/**
 * A [Socket] that layers TLS on top of another [Socket] using an [SSLEngine].
 *
 * Plaintext written to [write] is wrapped by the engine and sent over [socket]; bytes read from [socket] are unwrapped
 * and handed out via [read]. When the handshake finishes, the peer's leaf certificate is checked against [hostname].
 *
 * @param scope scope in which the background read and write loops are launched once the initial handshake finishes.
 * @param socket the underlying socket carrying the encrypted traffic.
 * @param sslContext context used to create a fresh [SSLEngine] on each [connect].
 * @param hostname the host we expect to be talking to; passed to the engine (for SNI) and used to validate the
 * peer's certificate.
 */
internal class TlsSocket(
    private val scope: CoroutineScope,
    private val socket: Socket,
    private val sslContext: SSLContext,
    private val hostname: String
) : Socket {

    private val log by logger()
    private var engine: SSLEngine = sslContext.createSSLEngine()

    /** Encrypted bytes read from [socket] that the engine has not consumed yet. */
    private var incomingNetBuffer = ByteBuffer.allocate(0)

    /** Plaintext pulled from [writeChannel] that is waiting to be wrapped by the engine. */
    private var incomingAppBuffer = ByteBuffer.allocate(0)

    /** Unwrapped (decrypted) plaintext chunks waiting to be handed out by [read]. */
    private var outgoingAppBuffers = Channel<ByteBuffer>(capacity = Channel.UNLIMITED)

    private var writeChannel = ByteChannel(autoFlush = true)

    override val write: ByteWriteChannel
        get() = writeChannel

    override val isOpen: Boolean
        get() = socket.isOpen

    override fun bind(socketAddress: SocketAddress) {
        socket.bind(socketAddress)
    }

    /**
     * Connects the underlying socket, performs the TLS handshake and then starts the background read/write loops.
     *
     * @throws CertificateException if the peer's certificate is not valid for [hostname].
     */
    override suspend fun connect(socketAddress: SocketAddress) {
        writeChannel = ByteChannel(autoFlush = true)

        // Giving the engine the peer host lets JSSE send the SNI extension, so servers that host several names present
        // the certificate for the one we asked for. -1 is used when the address doesn't expose a port.
        val port = (socketAddress as? java.net.InetSocketAddress)?.port ?: -1
        engine = sslContext.createSSLEngine(hostname, port).apply {
            useClientMode = true
        }

        incomingNetBuffer = ByteBuffer.allocate(engine.session.packetBufferSize)
        outgoingAppBuffers = Channel(capacity = Channel.UNLIMITED)
        incomingAppBuffer = ByteBuffer.allocate(engine.session.applicationBufferSize)

        socket.connect(socketAddress)

        engine.beginHandshake()

        sslLoop(startLoops = true)
    }

    /**
     * Drives the engine's handshake state machine until it no longer needs a task, wrap or unwrap from us.
     *
     * @param initialResult result of a wrap/unwrap the caller has already performed; its handshake status is the
     * starting point. When null, the engine's current handshake status is used instead.
     * @param startLoops whether to launch [readLoop] and [writeLoop] when the handshake finishes. Only the initial
     * handshake from [connect] does this: a handshake that completes while the loops are already running (e.g. a
     * renegotiation observed from within those loops) must not launch a second pair of loops.
     * @throws CertificateException if a finished handshake yields a certificate that isn't valid for [hostname].
     */
    private suspend fun sslLoop(initialResult: SSLEngineResult? = null, startLoops: Boolean = false) {
        var handshakeStatus: SSLEngineResult.HandshakeStatus? =
            initialResult?.handshakeStatus ?: engine.handshakeStatus
        while (true) {
            handshakeStatus = when (handshakeStatus) {
                SSLEngineResult.HandshakeStatus.NEED_TASK -> {
                    runDelegatedTasks()
                    engine.handshakeStatus
                }

                SSLEngineResult.HandshakeStatus.NEED_WRAP -> wrap()?.handshakeStatus

                SSLEngineResult.HandshakeStatus.NEED_UNWRAP -> unwrap()?.handshakeStatus

                SSLEngineResult.HandshakeStatus.FINISHED -> {
                    verifyPeerCertificate()
                    if (startLoops) {
                        scope.launch { readLoop() }
                        scope.launch { writeLoop() }
                    }
                    return
                }

                // NOT_HANDSHAKING, or null when wrap/unwrap produced no result (unwrap returns null after closing on
                // end-of-stream): nothing more to drive.
                else -> return
            }
        }
    }

    /** Runs every delegated task the engine currently has available, tolerating there being none. */
    private fun runDelegatedTasks() {
        while (true) {
            val task = engine.delegatedTask ?: return
            task.run()
        }
    }

    /**
     * Checks the peer's leaf certificate against [hostname].
     *
     * Fails closed: a missing certificate, or one that isn't an [X509Certificate] (and so can't be name-checked), is
     * rejected just like a certificate issued for a different host.
     *
     * @throws CertificateException if the certificate is missing, not X.509, or not valid for [hostname].
     */
    private fun verifyPeerCertificate() {
        val leaf = engine.session.peerCertificates.firstOrNull() as? X509Certificate
        if (leaf?.validFor(hostname) != true) {
            throw CertificateException("Certificate is not valid for $hostname")
        }
    }

    /**
     * Copies the next chunk of decrypted data into [buffer].
     *
     * @return the number of bytes copied, or -1 once the socket has been closed and no data remains.
     */
    override suspend fun read(buffer: ByteBuffer) = try {
        val nextBuffer = outgoingAppBuffers.receive()
        val bytes = nextBuffer.limit()
        // NOTE(review): assumes [buffer] has room for a whole decrypted chunk; put() throws otherwise.
        buffer.put(nextBuffer)
        defaultPool.recycle(nextBuffer)
        bytes
    } catch (_: ClosedReceiveChannelException) {
        -1
    }

    /**
     * Wraps pending data with the engine and writes the produced bytes to the underlying socket.
     *
     * Outside of a handshake this first suspends until some plaintext is available from [writeChannel].
     *
     * @return the result of [SSLEngine.wrap] (nullable only because it is captured from inside the pool borrow).
     */
    private suspend fun wrap(): SSLEngineResult? {
        var result: SSLEngineResult? = null
        defaultPool.borrow { netBuffer ->
            // Only pull new application data from the writer when no handshake is in progress.
            val status = engine.handshakeStatus
            if (status == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING ||
                status == SSLEngineResult.HandshakeStatus.FINISHED
            ) {
                writeChannel.readAvailable(incomingAppBuffer)
            }
            incomingAppBuffer.flip()
            result = engine.wrap(incomingAppBuffer, netBuffer)
            incomingAppBuffer.compact()

            netBuffer.flip()
            socket.write.writeFully(netBuffer)
        }
        return result
    }

    /**
     * Unwraps buffered/incoming network data, queueing any resulting plaintext for [read].
     *
     * @param networkRead whether to read more bytes from [socket] first; by default only when nothing is buffered.
     * @return the result of [SSLEngine.unwrap], or null if the underlying socket hit end-of-stream (in which case this
     * socket has been closed).
     */
    private suspend fun unwrap(networkRead: Boolean = incomingNetBuffer.position() == 0): SSLEngineResult? {
        if (networkRead) {
            val bytes = socket.read(incomingNetBuffer.slice())
            if (bytes == -1) {
                close()
                return null
            }
            // The read went into a slice, which has its own position; advance ours past the new bytes.
            incomingNetBuffer.position(incomingNetBuffer.position() + bytes)
        }

        incomingNetBuffer.flip()

        val buffer = defaultPool.borrow()
        val result = engine.unwrap(incomingNetBuffer, buffer)
        incomingNetBuffer.compact()
        if (buffer.position() > 0) {
            buffer.flip()
            outgoingAppBuffers.send(buffer)
        } else {
            defaultPool.recycle(buffer)
        }

        return if (result?.status == SSLEngineResult.Status.BUFFER_UNDERFLOW && !networkRead) {
            // We didn't do a network read, but SSLEngine is unhappy; force a read.
            log.finest { "Incoming net buffer underflowed, forcing re-read" }
            unwrap(true)
        } else {
            result
        }
    }

    /** Closes the underlying socket and releases any decrypted buffers that were never read. */
    override fun close() {
        socket.close()

        // Release any buffers we've got queued up
        generateSequence { outgoingAppBuffers.poll() }.forEach { defaultPool.recycle(it) }

        outgoingAppBuffers.close()
    }

    /** Keeps unwrapping incoming data (and servicing any handshake it triggers) while the socket is open. */
    private suspend fun readLoop() {
        while (socket.isOpen) {
            sslLoop(unwrap())
        }
    }

    /** Keeps wrapping outgoing data (and servicing any handshake it triggers) while the socket is open. */
    private suspend fun writeLoop() {
        while (socket.isOpen) {
            sslLoop(wrap())
        }
    }

}
-
/**
 * Checks whether this certificate names [host].
 *
 * A certificate name matches when it has the same number of dot-separated labels as [host], its left-most label
 * matches the host's left-most label (allowing a single `*` wildcard), and every remaining label is equal ignoring
 * case.
 */
internal fun X509Certificate.validFor(host: String): Boolean {
    val hostLabels = host.split('.')
    return allNames.any { name ->
        val nameLabels = name.split('.')
        nameLabels.size == hostLabels.size &&
            nameLabels[0].wildCardMatches(hostLabels[0]) &&
            (1 until hostLabels.size).all { index -> nameLabels[index].equals(hostLabels[index], ignoreCase = true) }
    }
}
-
/**
 * Matches a single certificate name label (this string) against a single host label, case-insensitively.
 *
 * The label may contain at most one `*`, which matches any run of characters; everything else is literal.
 */
private fun String.wildCardMatches(host: String): Boolean {
    if (count { it == '*' } > 1) {
        return false
    }
    val literalPieces = split('*').map { piece -> Pattern.quote(piece) }
    val pattern = Regex(literalPieces.joinToString(separator = ".*"), RegexOption.IGNORE_CASE)
    return pattern.matches(host)
}
-
/** Every name this certificate claims: the subject common name (if any) first, then the DNS alternative names. */
private val X509Certificate.allNames: Sequence<String>
    get() = sequence {
        val cn = commonName
        if (cn != null) {
            yield(cn)
        }
        subjectAlternateNames.forEach { name -> yield(name) }
    }
-
/**
 * The certificate's DNS subject alternative names, or an empty set if there are none or they can't be read.
 */
private val X509Certificate.subjectAlternateNames: Set<String>
    get() = nullOnThrow {
        // Each entry is a [type, value] list; type 2 is dNSName (see X509Certificate.getSubjectAlternativeNames).
        subjectAlternativeNames?.mapNotNullTo(LinkedHashSet<String>()) { entry ->
            if (entry[0] == 2) entry[1].toString() else null
        }
    } ?: emptySet()
-
/** The first `CN` value in the certificate's subject, or null if there is none or the subject can't be parsed. */
private val X509Certificate.commonName: String?
    get() = nullOnThrow {
        rdns["CN"].orEmpty().firstOrNull()?.value?.toString()
    }
-
/** The relative distinguished names of the certificate's subject, grouped by their upper-cased attribute type. */
private val X509Certificate.rdns: Map<String, List<Rdn>>
    get() {
        val subjectName = LdapName(subjectX500Principal.name)
        return subjectName.rdns.groupBy { rdn -> rdn.type.toUpperCase() }
    }
-
/**
 * Runs [block] and returns its result, or null if it throws anything at all (any [Throwable] is swallowed).
 */
private inline fun <S> nullOnThrow(block: () -> S?): S? =
    runCatching(block).getOrNull()
|