您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

TlsTest.kt 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. package com.dmdirc.ktirc.io
  2. import io.mockk.every
  3. import io.mockk.mockk
  4. import kotlinx.coroutines.*
  5. import kotlinx.coroutines.io.writeFully
  6. import kotlinx.io.core.String
  7. import org.junit.jupiter.api.Assertions.*
  8. import org.junit.jupiter.api.Test
  9. import org.junit.jupiter.api.parallel.Execution
  10. import org.junit.jupiter.api.parallel.ExecutionMode
  11. import java.net.InetSocketAddress
  12. import java.net.ServerSocket
  13. import java.security.KeyStore
  14. import java.security.cert.CertificateException
  15. import java.security.cert.X509Certificate
  16. import javax.net.ssl.KeyManagerFactory
  17. import javax.net.ssl.SSLContext
  18. import javax.net.ssl.X509TrustManager
  19. internal class CertificateValidationTest {
  20. private val cert = mockk<X509Certificate>()
  21. @Test
  22. fun `checks common name`() {
  23. every { cert.subjectX500Principal } returns mockk {
  24. every { name } returns "CN=subdomain.test.ktirc,O=testing,L=London,C=GB"
  25. }
  26. assertTrue(cert.validFor("subdomain.test.ktirc"))
  27. assertFalse(cert.validFor("subdomain2.test.ktirc"))
  28. assertFalse(cert.validFor("testing"))
  29. }
  30. @Test
  31. fun `checks common name with suffixed wildcard`() {
  32. every { cert.subjectX500Principal } returns mockk {
  33. every { name } returns "CN=subdomain*.test.ktirc,O=testing,L=London,C=GB"
  34. }
  35. assertTrue(cert.validFor("subdomain.test.ktirc"))
  36. assertTrue(cert.validFor("subdomain2.test.ktirc"))
  37. assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
  38. assertFalse(cert.validFor("1subdomain.test.ktirc"))
  39. }
  40. @Test
  41. fun `checks common name with preixed wildcard`() {
  42. every { cert.subjectX500Principal } returns mockk {
  43. every { name } returns "CN=*subdomain.test.ktirc,O=testing,L=London,C=GB"
  44. }
  45. assertTrue(cert.validFor("subdomain.test.ktirc"))
  46. assertTrue(cert.validFor("1subdomain.test.ktirc"))
  47. assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
  48. assertFalse(cert.validFor("subdomain1.test.ktirc"))
  49. }
  50. @Test
  51. fun `checks common name with infixed wildcard`() {
  52. every { cert.subjectX500Principal } returns mockk {
  53. every { name } returns "CN=sub*domain.test.ktirc,O=testing,L=London,C=GB"
  54. }
  55. assertTrue(cert.validFor("subdomain.test.ktirc"))
  56. assertTrue(cert.validFor("SUB-domain.test.ktirc"))
  57. assertFalse(cert.validFor("foo.subdomain.test.ktirc"))
  58. assertFalse(cert.validFor("subdomain1.test.ktirc"))
  59. }
  60. @Test
  61. fun `ignores wildcards in CN if they're not left-most`() {
  62. every { cert.subjectX500Principal } returns mockk {
  63. every { name } returns "CN=foo.*domain.test.ktirc,O=testing,L=London,C=GB"
  64. }
  65. assertFalse(cert.validFor("foo.domain.test.ktirc"))
  66. assertFalse(cert.validFor("foo-test.domain.test.ktirc"))
  67. assertFalse(cert.validFor("foo.test-domain.test.ktirc"))
  68. }
  69. @Test
  70. fun `ignores wildcards in CN if there are too many`() {
  71. every { cert.subjectX500Principal } returns mockk {
  72. every { name } returns "CN=*domain*.test.ktirc,O=testing,L=London,C=GB"
  73. }
  74. assertFalse(cert.validFor("domain.test.ktirc"))
  75. assertFalse(cert.validFor("subdomain.test.ktirc"))
  76. assertFalse(cert.validFor("domain1.test.ktirc"))
  77. }
  78. @Test
  79. fun `checks all sans`() {
  80. every { cert.subjectAlternativeNames } returns listOf(
  81. listOf(4, "directory.test.ktirc"),
  82. listOf(2, "subdomain1.test.ktirc"),
  83. listOf(2, "subdomain2.test.ktirc"),
  84. listOf(2, "subdomain3.test.ktirc")
  85. )
  86. assertTrue(cert.validFor("subdomain1.test.ktirc"))
  87. assertTrue(cert.validFor("subdomain2.test.KTIRC"))
  88. assertTrue(cert.validFor("subdomain3.test.ktirc"))
  89. assertFalse(cert.validFor("directory.test.ktirc"))
  90. }
  91. @Test
  92. fun `checks wildcard sans`() {
  93. every { cert.subjectAlternativeNames } returns listOf(
  94. listOf(4, "directory.test.ktirc"),
  95. listOf(2, "*domain1.test.ktirc"),
  96. listOf(2, "subdomain*.test.ktirc"),
  97. listOf(2, "*foo*.test.ktirc"),
  98. listOf(2, "foo.*.ktirc")
  99. )
  100. assertTrue(cert.validFor("subdomain1.test.ktirc"))
  101. assertTrue(cert.validFor("subdomain2.test.ktirc"))
  102. assertTrue(cert.validFor("gooddomain1.TEST.ktirc"))
  103. assertFalse(cert.validFor("foo.test.ktirc"))
  104. }
  105. @Test
  106. fun `still uses CN if sans throws`() {
  107. every { cert.subjectX500Principal } returns mockk {
  108. every { name } returns "CN=subdomain.test.ktirc,O=testing,L=London,C=GB"
  109. }
  110. every { cert.subjectAlternativeNames } throws CertificateException("Oops")
  111. assertTrue(cert.validFor("subdomain.test.ktirc"))
  112. assertFalse(cert.validFor("subdomain2.test.ktirc"))
  113. assertFalse(cert.validFor("testing"))
  114. }
  115. @Test
  116. fun `still uses sans if CN throws`() {
  117. every { cert.subjectX500Principal } throws CertificateException("Oops")
  118. every { cert.subjectAlternativeNames } returns listOf(
  119. listOf(4, "directory.test.ktirc"),
  120. listOf(2, "subdomain1.test.ktirc"),
  121. listOf(2, "subdomain2.test.ktirc"),
  122. listOf(2, "subdomain3.test.ktirc")
  123. )
  124. assertTrue(cert.validFor("subdomain1.test.ktirc"))
  125. assertTrue(cert.validFor("subdomain2.test.KTIRC"))
  126. assertTrue(cert.validFor("subdomain3.test.ktirc"))
  127. assertFalse(cert.validFor("directory.test.ktirc"))
  128. }
  129. @Test
  130. fun `fails if CN and sans missing`() {
  131. assertFalse(cert.validFor("subdomain1.test.ktirc"))
  132. assertFalse(cert.validFor("subdomain2.test.KTIRC"))
  133. assertFalse(cert.validFor("subdomain3.test.ktirc"))
  134. assertFalse(cert.validFor("directory.test.ktirc"))
  135. }
  136. }
  137. @Suppress("BlockingMethodInNonBlockingContext")
  138. @Execution(ExecutionMode.SAME_THREAD)
  139. internal class TlsSocketTest {
  140. @Test
  141. fun `can send a string to a server over TLS`() = runBlocking {
  142. withTimeout(5000) {
  143. tlsServerSocket(12321).use { serverSocket ->
  144. val plainSocket = PlainTextSocket(GlobalScope)
  145. val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
  146. val clientBytesAsync = GlobalScope.async {
  147. ByteArray(13).apply {
  148. serverSocket.accept().getInputStream().read(this)
  149. }
  150. }
  151. tlsSocket.connect(InetSocketAddress("localhost", 12321))
  152. tlsSocket.write.writeFully("Hello World\r\n".toByteArray())
  153. val bytes = clientBytesAsync.await()
  154. assertNotNull(bytes)
  155. assertEquals("Hello World\r\n", String(bytes))
  156. }
  157. }
  158. }
  159. @Test
  160. fun `can read a string from a server over TLS`() = runBlocking<Unit> {
  161. withTimeout(5000) {
  162. tlsServerSocket(12321).use { serverSocket ->
  163. val plainSocket = PlainTextSocket(GlobalScope)
  164. val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
  165. val socket = GlobalScope.async {
  166. serverSocket.accept().apply {
  167. GlobalScope.launch {
  168. getInputStream().read()
  169. }
  170. }
  171. }
  172. tlsSocket.connect(InetSocketAddress("localhost", 12321))
  173. GlobalScope.launch {
  174. with(socket.await().getOutputStream()) {
  175. write("Hack the planet!".toByteArray())
  176. flush()
  177. }
  178. }
  179. val buffer = tlsSocket.read()
  180. assertNotNull(buffer)
  181. buffer?.let {
  182. assertEquals("Hack the planet!", String(it.array(), 0, it.limit()))
  183. }
  184. }
  185. }
  186. }
  187. @Test
  188. fun `read returns null after close`() = runBlocking {
  189. withTimeout(5000) {
  190. tlsServerSocket(12321).use { serverSocket ->
  191. val plainSocket = PlainTextSocket(GlobalScope)
  192. val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
  193. GlobalScope.launch {
  194. serverSocket.accept().getInputStream().read()
  195. }
  196. tlsSocket.connect(InetSocketAddress("localhost", 12321))
  197. tlsSocket.close()
  198. val buffer = tlsSocket.read()
  199. assertNull(buffer)
  200. }
  201. }
  202. }
  203. @Test
  204. fun `throws if the hostname mismatches`() {
  205. tlsServerSocket(12321).use { serverSocket ->
  206. val plainSocket = PlainTextSocket(GlobalScope)
  207. val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "127.0.0.1")
  208. GlobalScope.launch {
  209. serverSocket.accept().getInputStream().read()
  210. }
  211. runBlocking {
  212. withTimeout(5000) {
  213. try {
  214. tlsSocket.connect(InetSocketAddress("localhost", 12321))
  215. fail<Unit>("Expected an exception")
  216. } catch (ex: Exception) {
  217. assertTrue(ex is CertificateException)
  218. }
  219. }
  220. }
  221. }
  222. }
  223. }
  224. internal fun tlsServerSocket(port: Int): ServerSocket {
  225. val keyStore = KeyStore.getInstance("PKCS12")
  226. keyStore.load(CertificateValidationTest::class.java.getResourceAsStream("localhost.p12"), CharArray(0))
  227. val keyManagerFactory = KeyManagerFactory.getInstance("PKIX")
  228. keyManagerFactory.init(keyStore, CharArray(0))
  229. val sslContext = SSLContext.getInstance("TLSv1.2")
  230. sslContext.init(keyManagerFactory.keyManagers, null, null)
  231. return sslContext.serverSocketFactory.createServerSocket(port)
  232. }
  233. internal fun getTrustingContext() =
  234. SSLContext.getInstance("TLSv1.2").apply { init(null, arrayOf(getTrustingManager()), null) }
  235. internal fun getTrustingManager() = object : X509TrustManager {
  236. override fun getAcceptedIssuers(): Array<X509Certificate> = emptyArray()
  237. override fun checkClientTrusted(certs: Array<X509Certificate>, authType: String) {}
  238. override fun checkServerTrusted(certs: Array<X509Certificate>, authType: String) {}
  239. }