|
@@ -5,7 +5,9 @@ import io.mockk.mockk
|
5
|
5
|
import kotlinx.coroutines.*
|
6
|
6
|
import kotlinx.coroutines.io.writeFully
|
7
|
7
|
import kotlinx.io.core.String
|
|
8
|
+import org.junit.jupiter.api.AfterEach
|
8
|
9
|
import org.junit.jupiter.api.Assertions.*
|
|
10
|
+import org.junit.jupiter.api.BeforeEach
|
9
|
11
|
import org.junit.jupiter.api.Test
|
10
|
12
|
import org.junit.jupiter.api.parallel.Execution
|
11
|
13
|
import org.junit.jupiter.api.parallel.ExecutionMode
|
|
@@ -17,6 +19,7 @@ import java.security.cert.X509Certificate
|
17
|
19
|
import javax.net.ssl.KeyManagerFactory
|
18
|
20
|
import javax.net.ssl.SSLContext
|
19
|
21
|
import javax.net.ssl.X509TrustManager
|
|
22
|
+import kotlin.coroutines.CoroutineContext
|
20
|
23
|
|
21
|
24
|
internal class CertificateValidationTest {
|
22
|
25
|
|
|
@@ -163,15 +166,28 @@ internal class CertificateValidationTest {
|
163
|
166
|
|
164
|
167
|
@Suppress("BlockingMethodInNonBlockingContext")
|
165
|
168
|
@Execution(ExecutionMode.SAME_THREAD)
|
166
|
|
-internal class TlsSocketTest {
|
|
169
|
+internal class TlsSocketTest: CoroutineScope {
|
|
170
|
+
|
|
171
|
+ override var coroutineContext: CoroutineContext = GlobalScope.coroutineContext
|
|
172
|
+
|
|
173
|
+ @ObsoleteCoroutinesApi
|
|
174
|
+ @BeforeEach
|
|
175
|
+ fun setup() {
|
|
176
|
+ coroutineContext = newFixedThreadPoolContext(4, "tls-test")
|
|
177
|
+ }
|
|
178
|
+
|
|
179
|
+ @AfterEach
|
|
180
|
+ fun teardown() {
|
|
181
|
+ coroutineContext.cancel()
|
|
182
|
+ }
|
167
|
183
|
|
168
|
184
|
@Test
|
169
|
|
- fun `can send a string to a server over TLS`() = runBlocking {
|
|
185
|
+ fun `can send a string to a server over TLS`() = runBlocking(coroutineContext) {
|
170
|
186
|
withTimeout(5000) {
|
171
|
187
|
tlsServerSocket(12321).use { serverSocket ->
|
172
|
|
- val plainSocket = PlainTextSocket(GlobalScope)
|
173
|
|
- val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
|
174
|
|
- val clientBytesAsync = GlobalScope.async {
|
|
188
|
+ val plainSocket = PlainTextSocket(this@TlsSocketTest)
|
|
189
|
+ val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "localhost")
|
|
190
|
+ val clientBytesAsync = this@TlsSocketTest.async {
|
175
|
191
|
ByteArray(13).apply {
|
176
|
192
|
serverSocket.accept().getInputStream().read(this)
|
177
|
193
|
}
|
|
@@ -188,14 +204,14 @@ internal class TlsSocketTest {
|
188
|
204
|
}
|
189
|
205
|
|
190
|
206
|
@Test
|
191
|
|
- fun `can read a string from a server over TLS`() = runBlocking<Unit> {
|
|
207
|
+ fun `can read a string from a server over TLS`() = runBlocking<Unit>(coroutineContext) {
|
192
|
208
|
withTimeout(5000) {
|
193
|
209
|
tlsServerSocket(12321).use { serverSocket ->
|
194
|
|
- val plainSocket = PlainTextSocket(GlobalScope)
|
195
|
|
- val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
|
196
|
|
- val socket = GlobalScope.async {
|
|
210
|
+ val plainSocket = PlainTextSocket(this@TlsSocketTest)
|
|
211
|
+ val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "localhost")
|
|
212
|
+ val socket = this@TlsSocketTest.async {
|
197
|
213
|
serverSocket.accept().apply {
|
198
|
|
- GlobalScope.launch {
|
|
214
|
+ this@TlsSocketTest.launch {
|
199
|
215
|
getInputStream().read()
|
200
|
216
|
}
|
201
|
217
|
}
|
|
@@ -203,7 +219,7 @@ internal class TlsSocketTest {
|
203
|
219
|
|
204
|
220
|
tlsSocket.connect(InetSocketAddress("localhost", 12321))
|
205
|
221
|
|
206
|
|
- GlobalScope.launch {
|
|
222
|
+ this@TlsSocketTest.launch {
|
207
|
223
|
with(socket.await().getOutputStream()) {
|
208
|
224
|
write("Hack the planet!".toByteArray())
|
209
|
225
|
flush()
|
|
@@ -221,12 +237,12 @@ internal class TlsSocketTest {
|
221
|
237
|
}
|
222
|
238
|
|
223
|
239
|
@Test
|
224
|
|
- fun `read returns null after close`() = runBlocking {
|
|
240
|
+ fun `read returns null after close`() = runBlocking(coroutineContext) {
|
225
|
241
|
withTimeout(5000) {
|
226
|
242
|
tlsServerSocket(12321).use { serverSocket ->
|
227
|
|
- val plainSocket = PlainTextSocket(GlobalScope)
|
228
|
|
- val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
|
229
|
|
- GlobalScope.launch {
|
|
243
|
+ val plainSocket = PlainTextSocket(this@TlsSocketTest)
|
|
244
|
+ val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "localhost")
|
|
245
|
+ this@TlsSocketTest.launch {
|
230
|
246
|
serverSocket.accept().getInputStream().read()
|
231
|
247
|
}
|
232
|
248
|
|
|
@@ -244,13 +260,13 @@ internal class TlsSocketTest {
|
244
|
260
|
@Test
|
245
|
261
|
fun `throws if the hostname mismatches`() {
|
246
|
262
|
tlsServerSocket(12321).use { serverSocket ->
|
247
|
|
- val plainSocket = PlainTextSocket(GlobalScope)
|
248
|
|
- val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "127.0.0.1")
|
249
|
|
- GlobalScope.launch {
|
|
263
|
+ val plainSocket = PlainTextSocket(this@TlsSocketTest)
|
|
264
|
+ val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "127.0.0.1")
|
|
265
|
+ launch {
|
250
|
266
|
serverSocket.accept().getInputStream().read()
|
251
|
267
|
}
|
252
|
268
|
|
253
|
|
- runBlocking {
|
|
269
|
+ runBlocking(coroutineContext) {
|
254
|
270
|
withTimeout(5000) {
|
255
|
271
|
try {
|
256
|
272
|
tlsSocket.connect(InetSocketAddress("localhost", 12321))
|
|
@@ -262,8 +278,6 @@ internal class TlsSocketTest {
|
262
|
278
|
}
|
263
|
279
|
}
|
264
|
280
|
}
|
265
|
|
-
|
266
|
|
-
|
267
|
281
|
}
|
268
|
282
|
|
269
|
283
|
internal fun tlsServerSocket(port: Int): ServerSocket {
|