You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

TlsTest.kt 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. package com.dmdirc.ktirc.io
  2. import io.mockk.every
  3. import io.mockk.mockk
  4. import kotlinx.coroutines.*
  5. import kotlinx.coroutines.io.writeFully
  6. import kotlinx.io.core.String
  7. import org.junit.jupiter.api.AfterEach
  8. import org.junit.jupiter.api.Assertions.*
  9. import org.junit.jupiter.api.BeforeEach
  10. import org.junit.jupiter.api.Test
  11. import org.junit.jupiter.api.parallel.Execution
  12. import org.junit.jupiter.api.parallel.ExecutionMode
  13. import java.net.InetSocketAddress
  14. import java.net.ServerSocket
  15. import java.security.KeyStore
  16. import java.security.cert.CertificateException
  17. import java.security.cert.X509Certificate
  18. import javax.net.ssl.KeyManagerFactory
  19. import javax.net.ssl.SSLContext
  20. import javax.net.ssl.X509TrustManager
  21. import kotlin.coroutines.CoroutineContext
  /**
   * Tests for the `validFor` hostname check on [X509Certificate].
   *
   * Each test stubs the subject distinguished name and/or the subject alternative
   * names (SANs) on a MockK certificate and then asserts which hostnames `validFor`
   * accepts. Properties that a test does not stub are left unstubbed on the mock.
   */
  internal class CertificateValidationTest {

      private val cert = mockk<X509Certificate>()

      @Test
      fun `checks common name`() {
          givenSubject("CN=subdomain.test.ktirc,O=testing,L=London,C=GB")

          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain2.test.ktirc"))
          assertFalse(cert.validFor("testing"))
      }

      @Test
      fun `checks common name with suffixed wildcard`() {
          givenSubject("CN=subdomain*.test.ktirc,O=testing,L=London,C=GB")

          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.ktirc"))
          assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
          assertFalse(cert.validFor("1subdomain.test.ktirc"))
      }

      @Test
      fun `checks common name with prefixed wildcard`() {
          givenSubject("CN=*subdomain.test.ktirc,O=testing,L=London,C=GB")

          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertTrue(cert.validFor("1subdomain.test.ktirc"))
          assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain1.test.ktirc"))
      }

      @Test
      fun `checks common name with infixed wildcard`() {
          givenSubject("CN=sub*domain.test.ktirc,O=testing,L=London,C=GB")

          assertTrue(cert.validFor("subdomain.test.ktirc"))
          // Mixed case on purpose: the match is expected to ignore case.
          assertTrue(cert.validFor("SUB-domain.test.ktirc"))
          assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain1.test.ktirc"))
      }

      @Test
      fun `ignores wildcards in CN if they're not left-most`() {
          givenSubject("CN=foo.*domain.test.ktirc,O=testing,L=London,C=GB")

          assertFalse(cert.validFor("foo.domain.test.ktirc"))
          assertFalse(cert.validFor("foo-test.domain.test.ktirc"))
          assertFalse(cert.validFor("foo.test-domain.test.ktirc"))
      }

      @Test
      fun `ignores wildcards in CN if there are too many`() {
          givenSubject("CN=*domain*.test.ktirc,O=testing,L=London,C=GB")

          assertFalse(cert.validFor("domain.test.ktirc"))
          assertFalse(cert.validFor("subdomain.test.ktirc"))
          assertFalse(cert.validFor("domain1.test.ktirc"))
      }

      @Test
      fun `checks all sans`() {
          givenSans(
              DIRECTORY_NAME to "directory.test.ktirc",
              DNS_NAME to "subdomain1.test.ktirc",
              DNS_NAME to "subdomain2.test.ktirc",
              DNS_NAME to "subdomain3.test.ktirc"
          )

          assertTrue(cert.validFor("subdomain1.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.KTIRC"))
          assertTrue(cert.validFor("subdomain3.test.ktirc"))
          // Only DNS entries count; the directory-name entry must be ignored.
          assertFalse(cert.validFor("directory.test.ktirc"))
      }

      @Test
      fun `checks wildcard sans`() {
          givenSans(
              DIRECTORY_NAME to "directory.test.ktirc",
              DNS_NAME to "*domain1.test.ktirc",
              DNS_NAME to "subdomain*.test.ktirc",
              DNS_NAME to "*foo*.test.ktirc",
              DNS_NAME to "foo.*.ktirc"
          )

          assertTrue(cert.validFor("subdomain1.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.ktirc"))
          assertTrue(cert.validFor("gooddomain1.TEST.ktirc"))
          assertFalse(cert.validFor("foo.test.ktirc"))
      }

      @Test
      fun `still uses CN if sans throws`() {
          givenSubject("CN=subdomain.test.ktirc,O=testing,L=London,C=GB")
          every { cert.subjectAlternativeNames } throws CertificateException("Oops")

          assertTrue(cert.validFor("subdomain.test.ktirc"))
          assertFalse(cert.validFor("subdomain2.test.ktirc"))
          assertFalse(cert.validFor("testing"))
      }

      @Test
      fun `still uses sans if CN throws`() {
          every { cert.subjectX500Principal } throws CertificateException("Oops")
          givenSans(
              DIRECTORY_NAME to "directory.test.ktirc",
              DNS_NAME to "subdomain1.test.ktirc",
              DNS_NAME to "subdomain2.test.ktirc",
              DNS_NAME to "subdomain3.test.ktirc"
          )

          assertTrue(cert.validFor("subdomain1.test.ktirc"))
          assertTrue(cert.validFor("subdomain2.test.KTIRC"))
          assertTrue(cert.validFor("subdomain3.test.ktirc"))
          assertFalse(cert.validFor("directory.test.ktirc"))
      }

      @Test
      fun `fails if CN and sans missing`() {
          // Nothing is stubbed on the mock here, so neither a CN nor any SAN is available.
          assertFalse(cert.validFor("subdomain1.test.ktirc"))
          assertFalse(cert.validFor("subdomain2.test.KTIRC"))
          assertFalse(cert.validFor("subdomain3.test.ktirc"))
          assertFalse(cert.validFor("directory.test.ktirc"))
      }

      /** Stubs the certificate's subject principal so that its name is [distinguishedName]. */
      private fun givenSubject(distinguishedName: String) {
          every { cert.subjectX500Principal } returns mockk {
              every { name } returns distinguishedName
          }
      }

      /**
       * Stubs the certificate's subject alternative names.
       *
       * @param entries pairs of GeneralName type code and value, turned into the
       * `[type, value]` lists that `X509Certificate.getSubjectAlternativeNames` returns.
       */
      private fun givenSans(vararg entries: Pair<Int, String>) {
          every { cert.subjectAlternativeNames } returns entries.map { (type, value) ->
              listOf<Any>(type, value)
          }
      }

      private companion object {
          /** GeneralName type code for a dNSName entry. */
          const val DNS_NAME = 2

          /** GeneralName type code for a directoryName entry. */
          const val DIRECTORY_NAME = 4
      }
  }
  /**
   * Integration tests that connect a [TlsSocket] to a real TLS server socket on localhost.
   *
   * The class is itself a [CoroutineScope]. Server-side helper coroutines are started on
   * this scope (not on the `runBlocking` scope) so that a blocking `accept()`/`read()`
   * never prevents a test body from completing.
   */
  @Suppress("BlockingMethodInNonBlockingContext")
  @Execution(ExecutionMode.SAME_THREAD)
  internal class TlsSocketTest : CoroutineScope {

      /** Placeholder until [setup] installs the per-test context. */
      override var coroutineContext: CoroutineContext = GlobalScope.coroutineContext

      /** Thread pool backing the current test; owned by this class and closed in [teardown]. */
      private var threadPool: ExecutorCoroutineDispatcher? = null

      @ObsoleteCoroutinesApi
      @BeforeEach
      fun setup() {
          val pool = newFixedThreadPoolContext(4, "tls-test")
          threadPool = pool
          // A bare dispatcher carries no Job, so cancelling it would do nothing. The
          // SupervisorJob gives teardown something to cancel, and keeps a failing helper
          // coroutine from cancelling the test body or its siblings.
          coroutineContext = pool + SupervisorJob()
      }

      @AfterEach
      fun teardown() {
          coroutineContext.cancel()
          // The dispatcher owns its threads; they are only released by close().
          threadPool?.close()
          threadPool = null
      }

      @Test
      fun `can send a string to a server over TLS`() = runBlocking<Unit>(coroutineContext) {
          withTimeout(TIMEOUT_MILLIS) {
              tlsServerSocket(PORT).use { serverSocket ->
                  val tlsSocket = newTlsSocket("localhost")
                  // 13 bytes == "Hello World\r\n"
                  val received = this@TlsSocketTest.async {
                      ByteArray(13).also { serverSocket.accept().getInputStream().read(it) }
                  }

                  tlsSocket.connect(InetSocketAddress("localhost", PORT))
                  tlsSocket.write.writeFully("Hello World\r\n".toByteArray())

                  assertEquals("Hello World\r\n", String(received.await()))
              }
          }
      }

      @Test
      fun `can read a string from a server over TLS`() = runBlocking<Unit>(coroutineContext) {
          withTimeout(TIMEOUT_MILLIS) {
              tlsServerSocket(PORT).use { serverSocket ->
                  val tlsSocket = newTlsSocket("localhost")
                  val accepted = this@TlsSocketTest.async {
                      serverSocket.accept().apply {
                          // Presumably this blocking read is what drives the server side of
                          // the handshake while the client connects - TODO confirm.
                          this@TlsSocketTest.launch { getInputStream().read() }
                      }
                  }

                  tlsSocket.connect(InetSocketAddress("localhost", PORT))

                  this@TlsSocketTest.launch {
                      with(accepted.await().getOutputStream()) {
                          write("Hack the planet!".toByteArray())
                          flush()
                      }
                  }

                  val buffer = tlsSocket.read()
                  assertNotNull(buffer)
                  buffer?.let {
                      assertEquals("Hack the planet!", String(it.array(), 0, it.limit()))
                  }
              }
          }
      }

      @Test
      fun `read returns null after close`() = runBlocking<Unit>(coroutineContext) {
          withTimeout(TIMEOUT_MILLIS) {
              tlsServerSocket(PORT).use { serverSocket ->
                  val tlsSocket = newTlsSocket("localhost")
                  this@TlsSocketTest.launch {
                      serverSocket.accept().getInputStream().read()
                  }

                  tlsSocket.connect(InetSocketAddress("localhost", PORT))
                  tlsSocket.close()

                  assertNull(tlsSocket.read())
              }
          }
      }

      @Test
      fun `throws if the hostname mismatches`() {
          tlsServerSocket(PORT).use { serverSocket ->
              // The socket expects "127.0.0.1"; the test keystore is named localhost.p12, so
              // presumably its certificate only covers "localhost" and validation must fail.
              val tlsSocket = newTlsSocket("127.0.0.1")
              launch {
                  serverSocket.accept().getInputStream().read()
              }

              assertThrows(CertificateException::class.java) {
                  runBlocking(coroutineContext) {
                      withTimeout(TIMEOUT_MILLIS) {
                          tlsSocket.connect(InetSocketAddress("localhost", PORT))
                      }
                  }
              }
          }
      }

      /**
       * Builds a [TlsSocket] over a fresh [PlainTextSocket], both running on this test's scope
       * and using a trust-everything SSL context, that validates certificates against [hostname].
       */
      private fun newTlsSocket(hostname: String) =
          TlsSocket(this, PlainTextSocket(this), getTrustingContext(), hostname)

      private companion object {
          /** Local port the test TLS server listens on. */
          const val PORT = 12321

          /** Upper bound, in milliseconds, for each test's network exchange. */
          const val TIMEOUT_MILLIS = 5000L
      }
  }
  /**
   * Creates a TLS server socket bound to [port], using the key material in the
   * `localhost.p12` classpath resource (loaded relative to [CertificateValidationTest],
   * with an empty password).
   *
   * @param port local port to listen on.
   * @return a listening server socket produced by a `TLSv1.2` [SSLContext].
   * @throws IllegalStateException if the `localhost.p12` resource cannot be found.
   */
  internal fun tlsServerSocket(port: Int): ServerSocket {
      val noPassword = CharArray(0)

      // Fail fast: KeyStore.load(null, ...) would silently create an empty keystore.
      val keyStoreStream = CertificateValidationTest::class.java.getResourceAsStream("localhost.p12")
          ?: error("Test keystore resource 'localhost.p12' not found")

      val keyStore = KeyStore.getInstance("PKCS12")
      // KeyStore.load does not close the stream it is given, so close it here.
      keyStoreStream.use { keyStore.load(it, noPassword) }

      val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
      keyManagerFactory.init(keyStore, noPassword)

      val sslContext = SSLContext.getInstance("TLSv1.2")
      sslContext.init(keyManagerFactory.keyManagers, null, null)

      return sslContext.serverSocketFactory.createServerSocket(port)
  }
  /**
   * Returns a `TLSv1.2` [SSLContext] initialised with no key managers, the default
   * random source, and the single trust manager from [getTrustingManager].
   */
  internal fun getTrustingContext() =
      SSLContext.getInstance("TLSv1.2").also { context ->
          val trustManagers = arrayOf(getTrustingManager())
          context.init(null, trustManagers, null)
      }
  /**
   * Returns an [X509TrustManager] whose trust checks are empty no-ops, so every
   * client and server certificate chain is accepted, and which reports no accepted issuers.
   */
  internal fun getTrustingManager() = object : X509TrustManager {

      /** Accepts any client chain: performs no validation. */
      override fun checkClientTrusted(chain: Array<X509Certificate>, authType: String) = Unit

      /** Accepts any server chain: performs no validation. */
      override fun checkServerTrusted(chain: Array<X509Certificate>, authType: String) = Unit

      /** Reports no accepted issuers. */
      override fun getAcceptedIssuers(): Array<X509Certificate> {
          return emptyArray()
      }
  }