You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

TlsTest.kt 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. package com.dmdirc.ktirc.io
  import io.mockk.every
  import io.mockk.mockk
  import kotlinx.coroutines.*
  import kotlinx.coroutines.io.writeFully
  import kotlinx.io.core.String
  import org.junit.jupiter.api.Assertions.*
  import org.junit.jupiter.api.Test
  import org.junit.jupiter.api.parallel.Execution
  import org.junit.jupiter.api.parallel.ExecutionMode
  import java.io.DataInputStream
  import java.net.InetSocketAddress
  import java.net.ServerSocket
  import java.security.KeyStore
  import java.security.cert.CertificateException
  import java.security.cert.X509Certificate
  import javax.net.ssl.KeyManagerFactory
  import javax.net.ssl.SSLContext
  import javax.net.ssl.X509TrustManager
  /**
   * Tests for the `validFor` hostname check on [X509Certificate].
   *
   * Covers matching against the subject common name (CN) and the subject alternative names (SANs),
   * wildcard handling, and falling back to one source when the other throws.
   */
  internal class CertificateValidationTest {

      // Assuming JUnit's default per-method lifecycle, every test gets a fresh, unstubbed mock.
      private val cert = mockk<X509Certificate>()

      @Test
      fun `checks common name`() {
          givenSubject("CN=subdomain.test.ktirc,O=testing,L=London,C=GB")
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain2.test.ktirc"))
          assertFalse(cert.validFor("testing"))
      }

      @Test
      fun `checks common name with suffixed wildcard`() {
          givenSubject("CN=subdomain*.test.ktirc,O=testing,L=London,C=GB")
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.ktirc"))
          assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
          assertFalse(cert.validFor("1subdomain.test.ktirc"))
      }

      @Test
      fun `checks common name with prefixed wildcard`() {
          givenSubject("CN=*subdomain.test.ktirc,O=testing,L=London,C=GB")
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertTrue(cert.validFor("1subdomain.test.ktirc"))
          assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain1.test.ktirc"))
      }

      @Test
      fun `checks common name with infixed wildcard`() {
          givenSubject("CN=sub*domain.test.ktirc,O=testing,L=London,C=GB")
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertTrue(cert.validFor("SUB-domain.test.ktirc"))
          assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain1.test.ktirc"))
      }

      @Test
      fun `ignores wildcards in CN if they're not left-most`() {
          givenSubject("CN=foo.*domain.test.ktirc,O=testing,L=London,C=GB")
          assertFalse(cert.validFor("foo.domain.test.ktirc"))
          assertFalse(cert.validFor("foo-test.domain.test.ktirc"))
          assertFalse(cert.validFor("foo.test-domain.test.ktirc"))
      }

      @Test
      fun `ignores wildcards in CN if there are too many`() {
          givenSubject("CN=*domain*.test.ktirc,O=testing,L=London,C=GB")
          assertFalse(cert.validFor("domain.test.ktirc"))
          assertFalse(cert.validFor("subdomain.test.ktirc"))
          assertFalse(cert.validFor("domain1.test.ktirc"))
      }

      @Test
      fun `checks all sans`() {
          givenSubjectAlternativeNames(
              DIRECTORY_NAME to "directory.test.ktirc",
              DNS_NAME to "subdomain1.test.ktirc",
              DNS_NAME to "subdomain2.test.ktirc",
              DNS_NAME to "subdomain3.test.ktirc"
          )
          assertTrue(cert.validFor("subdomain1.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.KTIRC"))
          assertTrue(cert.validFor("subdomain3.test.ktirc"))
          // Only DNS-name SANs may match a hostname; the directory-name entry must be ignored.
          assertFalse(cert.validFor("directory.test.ktirc"))
      }

      @Test
      fun `checks wildcard sans`() {
          givenSubjectAlternativeNames(
              DIRECTORY_NAME to "directory.test.ktirc",
              DNS_NAME to "*domain1.test.ktirc",
              DNS_NAME to "subdomain*.test.ktirc",
              DNS_NAME to "*foo*.test.ktirc",
              DNS_NAME to "foo.*.ktirc"
          )
          assertTrue(cert.validFor("subdomain1.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.ktirc"))
          assertTrue(cert.validFor("gooddomain1.TEST.ktirc"))
          // "*foo*" has too many wildcards and "foo.*" is not left-most, so neither may match.
          assertFalse(cert.validFor("foo.test.ktirc"))
      }

      @Test
      fun `still uses CN if sans throws`() {
          givenSubject("CN=subdomain.test.ktirc,O=testing,L=London,C=GB")
          every { cert.subjectAlternativeNames } throws CertificateException("Oops")
          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain2.test.ktirc"))
          assertFalse(cert.validFor("testing"))
      }

      @Test
      fun `still uses sans if CN throws`() {
          every { cert.subjectX500Principal } throws CertificateException("Oops")
          givenSubjectAlternativeNames(
              DIRECTORY_NAME to "directory.test.ktirc",
              DNS_NAME to "subdomain1.test.ktirc",
              DNS_NAME to "subdomain2.test.ktirc",
              DNS_NAME to "subdomain3.test.ktirc"
          )
          assertTrue(cert.validFor("subdomain1.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.KTIRC"))
          assertTrue(cert.validFor("subdomain3.test.ktirc"))
          assertFalse(cert.validFor("directory.test.ktirc"))
      }

      @Test
      fun `fails if CN and sans missing`() {
          // Nothing is stubbed, so the (strict) mock throws for both the subject and the SAN lookups.
          assertFalse(cert.validFor("subdomain1.test.ktirc"))
          assertFalse(cert.validFor("subdomain2.test.KTIRC"))
          assertFalse(cert.validFor("subdomain3.test.ktirc"))
          assertFalse(cert.validFor("directory.test.ktirc"))
      }

      /** Stubs the certificate's subject principal so that its distinguished name is [distinguishedName]. */
      private fun givenSubject(distinguishedName: String) {
          every { cert.subjectX500Principal } returns mockk {
              every { name } returns distinguishedName
          }
      }

      /**
       * Stubs the certificate's subject alternative names.
       *
       * Each entry pairs a GeneralName type tag ([DNS_NAME], [DIRECTORY_NAME]) with its value. Each
       * entry becomes a two-element `[type, value]` list, which is the shape reported by
       * [X509Certificate.getSubjectAlternativeNames].
       */
      private fun givenSubjectAlternativeNames(vararg sans: Pair<Int, String>) {
          every { cert.subjectAlternativeNames } returns sans.map { (type, value) -> listOf(type, value) }
      }

      private companion object {
          /** GeneralName tag for a `dNSName` entry in the SAN extension. */
          const val DNS_NAME = 2

          /** GeneralName tag for a `directoryName` entry in the SAN extension. */
          const val DIRECTORY_NAME = 4
      }
  }
  /**
   * Integration tests for [TlsSocket], run against a TLS server socket created with [tlsServerSocket].
   *
   * Each server socket binds port 0, so the OS assigns a free ephemeral port. Tests therefore never
   * collide on, or fail because of, a hard-coded port. Blocking server-side socket calls run on
   * [Dispatchers.IO] so they cannot tie up threads of the size-limited Default dispatcher.
   *
   * Test bodies use `runBlocking<Unit>` explicitly. JUnit Jupiter only runs `@Test` methods that
   * return `void`, and the explicit type argument guarantees that regardless of the last expression.
   */
  @Suppress("BlockingMethodInNonBlockingContext")
  @Execution(ExecutionMode.SAME_THREAD)
  internal class TlsSocketTest {

      @Test
      fun `can send a string to a server over TLS`() = runBlocking<Unit> {
          withTimeout(TIMEOUT_MILLIS) {
              tlsServerSocket(0).use { serverSocket ->
                  val payload = "Hello World\r\n".toByteArray()
                  val tlsSocket = newTlsSocket()
                  val receivedAsync = GlobalScope.async(Dispatchers.IO) {
                      serverSocket.accept().use { client ->
                          // readFully keeps reading until the whole payload has arrived, even if it is
                          // delivered across several reads.
                          ByteArray(payload.size).also { DataInputStream(client.getInputStream()).readFully(it) }
                      }
                  }
                  tlsSocket.connect(serverSocket.localhostAddress())
                  tlsSocket.write.writeFully(payload)
                  assertEquals("Hello World\r\n", String(receivedAsync.await()))
              }
          }
      }

      @Test
      fun `can read a string from a server over TLS`() = runBlocking<Unit> {
          withTimeout(TIMEOUT_MILLIS) {
              tlsServerSocket(0).use { serverSocket ->
                  val tlsSocket = newTlsSocket()
                  val serverSide = GlobalScope.async(Dispatchers.IO) {
                      serverSocket.accept().apply {
                          // An SSLSocket only starts its handshake once it performs I/O (or is told to),
                          // so keep a blocking read running on the server side.
                          GlobalScope.launch(Dispatchers.IO) {
                              getInputStream().read()
                          }
                      }
                  }
                  tlsSocket.connect(serverSocket.localhostAddress())
                  GlobalScope.launch(Dispatchers.IO) {
                      with(serverSide.await().getOutputStream()) {
                          write("Hack the planet!".toByteArray())
                          flush()
                      }
                  }
                  val buffer = tlsSocket.read()
                  assertNotNull(buffer)
                  assertEquals("Hack the planet!", buffer?.let { String(it.array(), 0, it.limit()) })
              }
          }
      }

      @Test
      fun `read returns null after close`() = runBlocking<Unit> {
          withTimeout(TIMEOUT_MILLIS) {
              tlsServerSocket(0).use { serverSocket ->
                  val tlsSocket = newTlsSocket()
                  GlobalScope.launch(Dispatchers.IO) {
                      serverSocket.accept().use { it.getInputStream().read() }
                  }
                  tlsSocket.connect(serverSocket.localhostAddress())
                  tlsSocket.close()
                  assertNull(tlsSocket.read())
              }
          }
      }

      @Test
      fun `throws if the hostname mismatches`() {
          tlsServerSocket(0).use { serverSocket ->
              // The trust manager accepts any chain, so the failure must come from the hostname check:
              // the test certificate is accepted for "localhost" above but must be rejected for "127.0.0.1".
              val tlsSocket = newTlsSocket("127.0.0.1")
              GlobalScope.launch(Dispatchers.IO) {
                  serverSocket.accept().use { it.getInputStream().read() }
              }
              assertThrows(CertificateException::class.java) {
                  runBlocking {
                      withTimeout(TIMEOUT_MILLIS) {
                          tlsSocket.connect(serverSocket.localhostAddress())
                      }
                  }
              }
          }
      }

      /**
       * Creates a client [TlsSocket] over a fresh [PlainTextSocket].
       *
       * It uses [getTrustingContext], so certificate chains are not verified; [hostname] is the name
       * the socket is told to expect from the server.
       */
      private fun newTlsSocket(hostname: String = "localhost") =
          TlsSocket(GlobalScope, PlainTextSocket(GlobalScope), getTrustingContext(), hostname)

      /** Address of this server socket as seen by a client connecting via "localhost". */
      private fun ServerSocket.localhostAddress() = InetSocketAddress("localhost", localPort)

      private companion object {
          /** Upper bound for each networked test, in milliseconds. */
          const val TIMEOUT_MILLIS = 5000L
      }
  }
  /**
   * Creates a TLSv1.2 server socket listening on [port].
   *
   * The socket serves the key material from the `localhost.p12` PKCS#12 resource, which uses an empty
   * password. The resource name is resolved relative to the package of [CertificateValidationTest].
   * Pass `0` as [port] to bind an OS-assigned ephemeral port and read it back via
   * [ServerSocket.getLocalPort].
   *
   * @throws IllegalStateException if the `localhost.p12` resource cannot be found.
   */
  internal fun tlsServerSocket(port: Int): ServerSocket {
      // A null stream would make KeyStore.load silently create an EMPTY key store, which only surfaces
      // later as an obscure handshake failure. Fail fast with a clear message instead.
      val keyStoreStream = checkNotNull(CertificateValidationTest::class.java.getResourceAsStream("localhost.p12")) {
          "Test key store 'localhost.p12' not found on the classpath"
      }
      val keyStore = KeyStore.getInstance("PKCS12")
      keyStoreStream.use { keyStore.load(it, CharArray(0)) }

      val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
      keyManagerFactory.init(keyStore, CharArray(0))

      val sslContext = SSLContext.getInstance("TLSv1.2")
      // Null trust managers / random source: JSSE falls back to its defaults for those.
      sslContext.init(keyManagerFactory.keyManagers, null, null)
      return sslContext.serverSocketFactory.createServerSocket(port)
  }
  /**
   * Builds a TLSv1.2 [SSLContext] whose only trust manager is the one returned by [getTrustingManager].
   *
   * No key managers are installed, and the default random source is used. Intended for tests only.
   */
  internal fun getTrustingContext(): SSLContext {
      val context = SSLContext.getInstance("TLSv1.2")
      context.init(null, arrayOf(getTrustingManager()), null)
      return context
  }
  /**
   * Returns an [X509TrustManager] that performs no validation at all.
   *
   * Every client and server chain is accepted, and no acceptable issuers are advertised. Only suitable
   * for tests.
   */
  internal fun getTrustingManager(): X509TrustManager = object : X509TrustManager {
      override fun checkClientTrusted(certs: Array<X509Certificate>, authType: String) = Unit

      override fun checkServerTrusted(certs: Array<X509Certificate>, authType: String) = Unit

      override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
  }