You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

TlsTest.kt 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. package com.dmdirc.ktirc.io
  2. import io.mockk.every
  3. import io.mockk.mockk
  4. import kotlinx.coroutines.GlobalScope
  5. import kotlinx.coroutines.async
  6. import kotlinx.coroutines.io.writeFully
  7. import kotlinx.coroutines.launch
  8. import kotlinx.coroutines.runBlocking
  9. import kotlinx.io.core.String
  10. import org.junit.jupiter.api.Assertions.*
  11. import org.junit.jupiter.api.Test
  12. import org.junit.jupiter.api.parallel.Execution
  13. import org.junit.jupiter.api.parallel.ExecutionMode
  14. import java.net.InetSocketAddress
  15. import java.net.ServerSocket
  16. import java.security.KeyStore
  17. import java.security.cert.CertificateException
  18. import java.security.cert.X509Certificate
  19. import javax.net.ssl.KeyManagerFactory
  20. import javax.net.ssl.SSLContext
  21. import javax.net.ssl.X509TrustManager
  /**
   * Tests for the `validFor` hostname check on [X509Certificate]: matching against the subject
   * common name (CN) and subject alternative names (SANs), wildcard rules, and fallback behaviour
   * when one of the two sources is unavailable.
   */
  internal class CertificateValidationTest {

      // Strict (non-relaxed) mock: any member a test does not stub throws when called, which is
      // how the "CN/SANs missing" and "… throws" scenarios below are simulated.
      private val cert = mockk<X509Certificate>()

      @Test
      fun `checks common name`() {
          givenCommonName("subdomain.test.ktirc")
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain2.test.ktirc"))
          assertFalse(cert.validFor("testing"))
      }

      @Test
      fun `checks common name with suffixed wildcard`() {
          givenCommonName("subdomain*.test.ktirc")
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.ktirc"))
          assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
          assertFalse(cert.validFor("1subdomain.test.ktirc"))
      }

      @Test
      fun `checks common name with prefixed wildcard`() {
          givenCommonName("*subdomain.test.ktirc")
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertTrue(cert.validFor("1subdomain.test.ktirc"))
          assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain1.test.ktirc"))
      }

      @Test
      fun `checks common name with infixed wildcard`() {
          givenCommonName("sub*domain.test.ktirc")
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          // Matching is expected to be case-insensitive.
          assertTrue(cert.validFor("SUB-domain.test.ktirc"))
          assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain1.test.ktirc"))
      }

      @Test
      fun `ignores wildcards in CN if they're not left-most`() {
          givenCommonName("foo.*domain.test.ktirc")
          assertFalse(cert.validFor("foo.domain.test.ktirc"))
          assertFalse(cert.validFor("foo-test.domain.test.ktirc"))
          assertFalse(cert.validFor("foo.test-domain.test.ktirc"))
      }

      @Test
      fun `ignores wildcards in CN if there are too many`() {
          givenCommonName("*domain*.test.ktirc")
          assertFalse(cert.validFor("domain.test.ktirc"))
          assertFalse(cert.validFor("subdomain.test.ktirc"))
          assertFalse(cert.validFor("domain1.test.ktirc"))
      }

      @Test
      fun `checks all sans`() {
          givenSubjectAltNames(
              DIRECTORY_NAME to "directory.test.ktirc",
              DNS_NAME to "subdomain1.test.ktirc",
              DNS_NAME to "subdomain2.test.ktirc",
              DNS_NAME to "subdomain3.test.ktirc",
          )
          assertTrue(cert.validFor("subdomain1.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.KTIRC"))
          assertTrue(cert.validFor("subdomain3.test.ktirc"))
          // Only DNS-name entries count; a directory-name entry must not match.
          assertFalse(cert.validFor("directory.test.ktirc"))
      }

      @Test
      fun `checks wildcard sans`() {
          givenSubjectAltNames(
              DIRECTORY_NAME to "directory.test.ktirc",
              DNS_NAME to "*domain1.test.ktirc",
              DNS_NAME to "subdomain*.test.ktirc",
              DNS_NAME to "*foo*.test.ktirc",
              DNS_NAME to "foo.*.ktirc",
          )
          assertTrue(cert.validFor("subdomain1.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.ktirc"))
          assertTrue(cert.validFor("gooddomain1.TEST.ktirc"))
          // "*foo*" has too many wildcards and "foo.*" is not left-most, so neither may match.
          assertFalse(cert.validFor("foo.test.ktirc"))
      }

      @Test
      fun `still uses CN if sans throws`() {
          givenCommonName("subdomain.test.ktirc")
          every { cert.subjectAlternativeNames } throws CertificateException("Oops")
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain2.test.ktirc"))
          assertFalse(cert.validFor("testing"))
      }

      @Test
      fun `still uses sans if CN throws`() {
          every { cert.subjectX500Principal } throws CertificateException("Oops")
          givenSubjectAltNames(
              DIRECTORY_NAME to "directory.test.ktirc",
              DNS_NAME to "subdomain1.test.ktirc",
              DNS_NAME to "subdomain2.test.ktirc",
              DNS_NAME to "subdomain3.test.ktirc",
          )
          assertTrue(cert.validFor("subdomain1.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.KTIRC"))
          assertTrue(cert.validFor("subdomain3.test.ktirc"))
          assertFalse(cert.validFor("directory.test.ktirc"))
      }

      @Test
      fun `fails if CN and sans missing`() {
          // Nothing is stubbed, so both the CN and SAN lookups throw on the strict mock.
          assertFalse(cert.validFor("subdomain1.test.ktirc"))
          assertFalse(cert.validFor("subdomain2.test.KTIRC"))
          assertFalse(cert.validFor("subdomain3.test.ktirc"))
          assertFalse(cert.validFor("directory.test.ktirc"))
      }

      /**
       * Stubs the certificate's subject DN so that its common name is [commonName]; the remaining
       * RDNs are fixed test values.
       *
       * The parameter is deliberately not called `name`: inside the `mockk {}` block that would
       * shadow the principal's own `name` property being stubbed.
       */
      private fun givenCommonName(commonName: String) {
          every { cert.subjectX500Principal } returns mockk {
              every { name } returns "CN=$commonName,O=testing,L=London,C=GB"
          }
      }

      /**
       * Stubs the certificate's subject alternative names, in order, as `[type, value]` pairs in the
       * shape returned by [X509Certificate.getSubjectAlternativeNames].
       */
      private fun givenSubjectAltNames(vararg entries: Pair<Int, String>) {
          every { cert.subjectAlternativeNames } returns entries.map { (type, value) -> listOf(type, value) }
      }

      companion object {
          /** GeneralName tag for a `dNSName` SAN entry (RFC 5280, section 4.2.1.6). */
          private const val DNS_NAME = 2

          /** GeneralName tag for a `directoryName` SAN entry (RFC 5280, section 4.2.1.6). */
          private const val DIRECTORY_NAME = 4
      }
  }
  /**
   * End-to-end tests for `TlsSocket` against a real local TLS server socket created by
   * [tlsServerSocket]. The client trusts any certificate ([getTrustingContext]), so only the
   * hostname check performed by `TlsSocket` itself is exercised.
   *
   * Each test binds an OS-assigned ephemeral port (port 0) rather than a fixed one, so runs cannot
   * collide with other processes or with each other.
   */
  @Suppress("BlockingMethodInNonBlockingContext")
  @Execution(ExecutionMode.SAME_THREAD)
  internal class TlsSocketTest {

      @Test
      fun `can send a string to a server over TLS`() = runBlocking<Unit> {
          tlsServerSocket(0).use { serverSocket ->
              val tlsSocket = newTlsSocket("localhost")
              val expectedSize = "Hello World\r\n".toByteArray().size
              val clientBytesAsync = GlobalScope.async {
                  ByteArray(expectedSize).apply {
                      // A single read() may legally return fewer bytes than requested; readFully
                      // blocks until the whole buffer is filled (or throws on EOF).
                      java.io.DataInputStream(serverSocket.accept().getInputStream()).readFully(this)
                  }
              }
              tlsSocket.connect(InetSocketAddress("localhost", serverSocket.localPort))
              tlsSocket.write.writeFully("Hello World\r\n".toByteArray())
              val bytes = clientBytesAsync.await()
              assertNotNull(bytes)
              assertEquals("Hello World\r\n", String(bytes))
          }
      }

      @Test
      fun `can read a string from a server over TLS`() = runBlocking<Unit> {
          tlsServerSocket(0).use { serverSocket ->
              val tlsSocket = newTlsSocket("localhost")
              val acceptedAsync = GlobalScope.async {
                  serverSocket.accept().also { accepted ->
                      // An SSLSocket starts its handshake on first I/O, so a pending read on the
                      // server side lets the client's connect() complete.
                      GlobalScope.launch { accepted.getInputStream().read() }
                  }
              }
              tlsSocket.connect(InetSocketAddress("localhost", serverSocket.localPort))
              GlobalScope.launch {
                  with(acceptedAsync.await().getOutputStream()) {
                      write("Hack the planet!".toByteArray())
                      flush()
                  }
              }
              val buffer = tlsSocket.read()
              assertNotNull(buffer)
              buffer?.let {
                  assertEquals("Hack the planet!", String(it.array(), 0, it.limit()))
              }
          }
      }

      @Test
      fun `read returns null after close`() = runBlocking<Unit> {
          tlsServerSocket(0).use { serverSocket ->
              val tlsSocket = newTlsSocket("localhost")
              GlobalScope.launch {
                  serverSocket.accept().getInputStream().read()
              }
              tlsSocket.connect(InetSocketAddress("localhost", serverSocket.localPort))
              tlsSocket.close()
              val buffer = tlsSocket.read()
              assertNull(buffer)
          }
      }

      @Test
      fun `throws if the hostname mismatches`() {
          tlsServerSocket(0).use { serverSocket ->
              // The test certificate is presented for "localhost", but the client expects "127.0.0.1".
              val tlsSocket = newTlsSocket("127.0.0.1")
              GlobalScope.launch {
                  serverSocket.accept().getInputStream().read()
              }
              assertThrows(CertificateException::class.java) {
                  runBlocking { tlsSocket.connect(InetSocketAddress("localhost", serverSocket.localPort)) }
              }
          }
      }

      /** Creates an unconnected `TlsSocket` that trusts any certificate and expects [hostname]. */
      private fun newTlsSocket(hostname: String) =
          TlsSocket(GlobalScope, PlainTextSocket(GlobalScope), getTrustingContext(), hostname)
  }
  /**
   * Creates a TLSv1.2 server socket for tests. The server presents the key and certificate stored
   * in the `localhost.p12` PKCS#12 test resource, which is loaded with an empty password.
   *
   * @param port the port to bind. Defaults to 0, which asks the OS for a free ephemeral port;
   *   read the chosen port back with [ServerSocket.getLocalPort].
   * @return a bound, unconnected server socket; the caller is responsible for closing it.
   * @throws IllegalStateException if the `localhost.p12` resource cannot be found. Without this
   *   check, `KeyStore.load(null, …)` would silently create an empty keystore and the failure would
   *   only surface later as an obscure handshake error.
   */
  internal fun tlsServerSocket(port: Int = 0): ServerSocket {
      val keyStore = KeyStore.getInstance("PKCS12")
      val keyStream = checkNotNull(CertificateValidationTest::class.java.getResourceAsStream("localhost.p12")) {
          "Test resource localhost.p12 not found"
      }
      // Close the resource stream once the keystore has been read from it.
      keyStream.use { keyStore.load(it, CharArray(0)) }

      val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
      keyManagerFactory.init(keyStore, CharArray(0))

      val sslContext = SSLContext.getInstance("TLSv1.2")
      // No trust managers: the server does not request or validate client certificates here.
      sslContext.init(keyManagerFactory.keyManagers, null, null)
      return sslContext.serverSocketFactory.createServerSocket(port)
  }
  /**
   * Builds a TLSv1.2 [SSLContext] whose sole trust manager is the accept-everything manager from
   * [getTrustingManager], so any server certificate chain passes the JSSE trust check.
   * Test-only: hostname and certificate checks must be done elsewhere.
   */
  internal fun getTrustingContext(): SSLContext {
      val context = SSLContext.getInstance("TLSv1.2")
      // Default key managers and SecureRandom; only the trust managers are overridden.
      context.init(null, arrayOf(getTrustingManager()), null)
      return context
  }
  /**
   * Returns an [X509TrustManager] that performs no validation at all. Every client and server
   * certificate chain is accepted, and no issuers are advertised.
   *
   * For tests only — never use this outside test code.
   */
  internal fun getTrustingManager(): X509TrustManager = object : X509TrustManager {
      override fun checkClientTrusted(certs: Array<X509Certificate>, authType: String) = Unit

      override fun checkServerTrusted(certs: Array<X509Certificate>, authType: String) = Unit

      override fun getAcceptedIssuers(): Array<X509Certificate> = emptyArray()
  }