You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

TlsTest.kt 8.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. package com.dmdirc.ktirc.io
  import io.mockk.every
  import io.mockk.mockk
  import kotlinx.coroutines.GlobalScope
  import kotlinx.coroutines.async
  import kotlinx.coroutines.io.writeFully
  import kotlinx.coroutines.launch
  import kotlinx.coroutines.runBlocking
  import kotlinx.io.core.String
  import org.junit.jupiter.api.Assertions
  import org.junit.jupiter.api.Assertions.*
  import org.junit.jupiter.api.Test
  import org.junit.jupiter.api.parallel.Execution
  import org.junit.jupiter.api.parallel.ExecutionMode
  import java.io.DataInputStream
  import java.net.InetSocketAddress
  import java.net.ServerSocket
  import java.security.KeyStore
  import java.security.cert.CertificateException
  import java.security.cert.X509Certificate
  import javax.net.ssl.KeyManagerFactory
  import javax.net.ssl.SSLContext
  import javax.net.ssl.X509TrustManager
  /**
   * Unit tests for the `X509Certificate.validFor(host)` hostname check.
   *
   * Each test stubs a mocked certificate's subject distinguished name and/or subject alternative names (SANs),
   * then asserts which host names `validFor` accepts.
   *
   * The mock is not relaxed, so any property a test leaves unstubbed throws when read. Tests that stub only one
   * source of names therefore also rely on `validFor` tolerating a failure from the other source (see the
   * explicit `still uses … if … throws` tests).
   */
  internal class CertificateValidationTest {

      private val cert = mockk<X509Certificate>()

      /** Stubs [cert] so that `subjectX500Principal.name` returns [distinguishedName]. */
      private fun givenSubjectName(distinguishedName: String) {
          every { cert.subjectX500Principal } returns mockk {
              every { name } returns distinguishedName
          }
      }

      @Test
      fun `checks common name`() {
          givenSubjectName("CN=subdomain.test.ktirc,O=testing,L=London,C=GB")
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain2.test.ktirc"))
          // Only the CN is matched, not other RDNs such as O=testing.
          assertFalse(cert.validFor("testing"))
      }

      @Test
      fun `checks common name with suffixed wildcard`() {
          givenSubjectName("CN=subdomain*.test.ktirc,O=testing,L=London,C=GB")
          // The wildcard may match zero or more characters within the left-most label...
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.ktirc"))
          // ...but it never spans an extra label, and the text before it must match.
          assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
          assertFalse(cert.validFor("1subdomain.test.ktirc"))
      }

      @Test
      fun `checks common name with prefixed wildcard`() {
          givenSubjectName("CN=*subdomain.test.ktirc,O=testing,L=London,C=GB")
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertTrue(cert.validFor("1subdomain.test.ktirc"))
          assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain1.test.ktirc"))
      }

      @Test
      fun `checks common name with infixed wildcard`() {
          givenSubjectName("CN=sub*domain.test.ktirc,O=testing,L=London,C=GB")
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          // Matching is case-insensitive ("SUB" against "sub").
          assertTrue(cert.validFor("SUB-domain.test.ktirc"))
          assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain1.test.ktirc"))
      }

      @Test
      fun `ignores wildcards in CN if they're not left-most`() {
          givenSubjectName("CN=foo.*domain.test.ktirc,O=testing,L=London,C=GB")
          assertFalse(cert.validFor("foo.domain.test.ktirc"))
          assertFalse(cert.validFor("foo-test.domain.test.ktirc"))
          assertFalse(cert.validFor("foo.test-domain.test.ktirc"))
      }

      @Test
      fun `ignores wildcards in CN if there are too many`() {
          givenSubjectName("CN=*domain*.test.ktirc,O=testing,L=London,C=GB")
          assertFalse(cert.validFor("domain.test.ktirc"))
          assertFalse(cert.validFor("subdomain.test.ktirc"))
          assertFalse(cert.validFor("domain1.test.ktirc"))
      }

      @Test
      fun `checks all sans`() {
          // SAN entries are [type, value] lists. Type 2 is dNSName and type 4 is directoryName
          // (see X509Certificate.getSubjectAlternativeNames). Only DNS names should match.
          every { cert.subjectAlternativeNames } returns listOf(
              listOf(4, "directory.test.ktirc"),
              listOf(2, "subdomain1.test.ktirc"),
              listOf(2, "subdomain2.test.ktirc"),
              listOf(2, "subdomain3.test.ktirc")
          )
          assertTrue(cert.validFor("subdomain1.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.KTIRC"))
          assertTrue(cert.validFor("subdomain3.test.ktirc"))
          assertFalse(cert.validFor("directory.test.ktirc"))
      }

      @Test
      fun `checks wildcard sans`() {
          every { cert.subjectAlternativeNames } returns listOf(
              listOf(4, "directory.test.ktirc"),
              listOf(2, "*domain1.test.ktirc"),
              listOf(2, "subdomain*.test.ktirc"),
              listOf(2, "*foo*.test.ktirc"),
              listOf(2, "foo.*.ktirc")
          )
          assertTrue(cert.validFor("subdomain1.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.ktirc"))
          assertTrue(cert.validFor("gooddomain1.TEST.ktirc"))
          // Neither "*foo*" (two wildcards) nor "foo.*" (wildcard not in the left-most label) may match.
          assertFalse(cert.validFor("foo.test.ktirc"))
      }

      @Test
      fun `still uses CN if sans throws`() {
          givenSubjectName("CN=subdomain.test.ktirc,O=testing,L=London,C=GB")
          every { cert.subjectAlternativeNames } throws CertificateException("Oops")
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain2.test.ktirc"))
          assertFalse(cert.validFor("testing"))
      }

      @Test
      fun `still uses sans if CN throws`() {
          every { cert.subjectX500Principal } throws CertificateException("Oops")
          every { cert.subjectAlternativeNames } returns listOf(
              listOf(4, "directory.test.ktirc"),
              listOf(2, "subdomain1.test.ktirc"),
              listOf(2, "subdomain2.test.ktirc"),
              listOf(2, "subdomain3.test.ktirc")
          )
          assertTrue(cert.validFor("subdomain1.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.KTIRC"))
          assertTrue(cert.validFor("subdomain3.test.ktirc"))
          assertFalse(cert.validFor("directory.test.ktirc"))
      }

      @Test
      fun `fails if CN and sans missing`() {
          // Nothing is stubbed, so both the subject principal and the SAN lookups throw from the mock.
          assertFalse(cert.validFor("subdomain1.test.ktirc"))
          assertFalse(cert.validFor("subdomain2.test.KTIRC"))
          assertFalse(cert.validFor("subdomain3.test.ktirc"))
          assertFalse(cert.validFor("directory.test.ktirc"))
      }
  }
  /**
   * Integration tests for [TlsSocket] talking to a real in-process TLS server.
   *
   * The server is built by [tlsServerSocket] from the `localhost.p12` test key store. The client uses
   * [getTrustingContext], whose trust manager accepts every chain. Any certificate rejection seen here therefore
   * comes from TlsSocket's own checks rather than from JDK chain validation.
   *
   * Both tests bind port 0 (an OS-assigned free port) instead of a fixed port, so they cannot collide with other
   * listeners on the machine. [GlobalScope] is kept as the sockets' scope, presumably so that `runBlocking` does
   * not wait on the sockets' own background coroutines.
   */
  @Suppress("BlockingMethodInNonBlockingContext")
  @Execution(ExecutionMode.SAME_THREAD)
  internal class TlsSocketTest {

      @Test
      fun `can send a string to a server over TLS`() = runBlocking<Unit> {
          val message = "Hello World\r\n"
          val messageBytes = message.toByteArray()
          tlsServerSocket(0).use { serverSocket ->
              val plainSocket = PlainTextSocket(GlobalScope)
              val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
              val clientBytesAsync = GlobalScope.async {
                  serverSocket.accept().use { client ->
                      // A single InputStream.read may return fewer bytes than requested. readFully keeps reading
                      // until the buffer is full, or throws EOFException if the peer closes first.
                      ByteArray(messageBytes.size).also { buffer ->
                          DataInputStream(client.getInputStream()).readFully(buffer)
                      }
                  }
              }
              // NOTE(review): tlsSocket is never closed here; disconnect it if TlsSocket exposes a way to.
              tlsSocket.connect(InetSocketAddress("localhost", serverSocket.localPort))
              tlsSocket.write.writeFully(messageBytes)
              assertEquals(message, String(clientBytesAsync.await()))
          }
      }

      @Test
      fun `throws if the hostname mismatches`() {
          tlsServerSocket(0).use { serverSocket ->
              val plainSocket = PlainTextSocket(GlobalScope)
              // The first test connects successfully as "localhost"; expecting "127.0.0.1" must be rejected.
              val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "127.0.0.1")
              GlobalScope.launch {
                  // The first read drives the server side of the TLS handshake.
                  serverSocket.accept().use { client -> client.getInputStream().read() }
              }
              assertThrows(CertificateException::class.java) {
                  runBlocking { tlsSocket.connect(InetSocketAddress("localhost", serverSocket.localPort)) }
              }
          }
      }
  }
  /**
   * Creates a TLS 1.2 server socket on [port], serving the key material from the `localhost.p12` test resource.
   *
   * The resource is resolved relative to the package of `CertificateValidationTest`. The store and key are
   * opened with an empty password.
   *
   * @param port the port to bind; `0` asks the OS for any free port (read it back via [ServerSocket.getLocalPort]).
   * @return an unconnected, listening server socket; the caller is responsible for closing it.
   * @throws IllegalStateException if the `localhost.p12` resource cannot be found.
   */
  internal fun tlsServerSocket(port: Int): ServerSocket {
      // KeyStore.load(null, ...) would silently create an EMPTY store, which only surfaces later as an obscure
      // handshake failure. Fail fast instead if the resource is missing.
      val keyStoreStream = checkNotNull(CertificateValidationTest::class.java.getResourceAsStream("localhost.p12")) {
          "Test resource localhost.p12 not found"
      }
      val keyStore = KeyStore.getInstance("PKCS12")
      keyStoreStream.use { keyStore.load(it, CharArray(0)) }

      val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
      keyManagerFactory.init(keyStore, CharArray(0))

      val sslContext = SSLContext.getInstance("TLSv1.2")
      // Only server key material is configured; the default trust managers and SecureRandom are used.
      sslContext.init(keyManagerFactory.keyManagers, null, null)
      return sslContext.serverSocketFactory.createServerSocket(port)
  }
  /**
   * Builds a TLS 1.2 [SSLContext] whose only trust manager is the one returned by `getTrustingManager()`.
   *
   * No key managers are installed, and the default source of randomness is used. For tests only.
   */
  internal fun getTrustingContext(): SSLContext {
      val context = SSLContext.getInstance("TLSv1.2")
      context.init(null, arrayOf(getTrustingManager()), null)
      return context
  }
  /**
   * Returns an [X509TrustManager] that performs no validation: both trust checks return without inspecting the
   * chain, and it reports no accepted issuers.
   *
   * This disables certificate chain validation entirely, so it must only be used in tests.
   */
  internal fun getTrustingManager(): X509TrustManager = object : X509TrustManager {
      override fun checkClientTrusted(certs: Array<X509Certificate>, authType: String) = Unit

      override fun checkServerTrusted(certs: Array<X509Certificate>, authType: String) = Unit

      override fun getAcceptedIssuers(): Array<X509Certificate> = emptyArray()
  }