Browse Source

Create a separate coroutine context per test

tags/v1.1.0
Chris Smith 5 years ago
parent
commit
136329c27d
1 changed files with 35 additions and 21 deletions
  1. 35
    21
      src/test/kotlin/com/dmdirc/ktirc/io/TlsTest.kt

+ 35
- 21
src/test/kotlin/com/dmdirc/ktirc/io/TlsTest.kt View File

@@ -5,7 +5,9 @@ import io.mockk.mockk
5 5
 import kotlinx.coroutines.*
6 6
 import kotlinx.coroutines.io.writeFully
7 7
 import kotlinx.io.core.String
8
+import org.junit.jupiter.api.AfterEach
8 9
 import org.junit.jupiter.api.Assertions.*
10
+import org.junit.jupiter.api.BeforeEach
9 11
 import org.junit.jupiter.api.Test
10 12
 import org.junit.jupiter.api.parallel.Execution
11 13
 import org.junit.jupiter.api.parallel.ExecutionMode
@@ -17,6 +19,7 @@ import java.security.cert.X509Certificate
17 19
 import javax.net.ssl.KeyManagerFactory
18 20
 import javax.net.ssl.SSLContext
19 21
 import javax.net.ssl.X509TrustManager
22
+import kotlin.coroutines.CoroutineContext
20 23
 
21 24
 internal class CertificateValidationTest {
22 25
 
@@ -163,15 +166,28 @@ internal class CertificateValidationTest {
163 166
 
164 167
 @Suppress("BlockingMethodInNonBlockingContext")
165 168
 @Execution(ExecutionMode.SAME_THREAD)
166
-internal class TlsSocketTest {
169
+internal class TlsSocketTest: CoroutineScope {
170
+
171
+    override var coroutineContext: CoroutineContext = GlobalScope.coroutineContext
172
+
173
+    @ObsoleteCoroutinesApi
174
+    @BeforeEach
175
+    fun setup() {
176
+        coroutineContext = newFixedThreadPoolContext(4, "tls-test")
177
+    }
178
+
179
+    @AfterEach
180
+    fun teardown() {
181
+        coroutineContext.cancel()
182
+    }
167 183
 
168 184
     @Test
169
-    fun `can send a string to a server over TLS`() = runBlocking {
185
+    fun `can send a string to a server over TLS`() = runBlocking(coroutineContext) {
170 186
         withTimeout(5000) {
171 187
             tlsServerSocket(12321).use { serverSocket ->
172
-                val plainSocket = PlainTextSocket(GlobalScope)
173
-                val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
174
-                val clientBytesAsync = GlobalScope.async {
188
+                val plainSocket = PlainTextSocket(this@TlsSocketTest)
189
+                val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "localhost")
190
+                val clientBytesAsync = this@TlsSocketTest.async {
175 191
                     ByteArray(13).apply {
176 192
                         serverSocket.accept().getInputStream().read(this)
177 193
                     }
@@ -188,14 +204,14 @@ internal class TlsSocketTest {
188 204
     }
189 205
 
190 206
     @Test
191
-    fun `can read a string from a server over TLS`() = runBlocking<Unit> {
207
+    fun `can read a string from a server over TLS`() = runBlocking<Unit>(coroutineContext) {
192 208
         withTimeout(5000) {
193 209
             tlsServerSocket(12321).use { serverSocket ->
194
-                val plainSocket = PlainTextSocket(GlobalScope)
195
-                val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
196
-                val socket = GlobalScope.async {
210
+                val plainSocket = PlainTextSocket(this@TlsSocketTest)
211
+                val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "localhost")
212
+                val socket = this@TlsSocketTest.async {
197 213
                     serverSocket.accept().apply {
198
-                        GlobalScope.launch {
214
+                        this@TlsSocketTest.launch {
199 215
                             getInputStream().read()
200 216
                         }
201 217
                     }
@@ -203,7 +219,7 @@ internal class TlsSocketTest {
203 219
 
204 220
                 tlsSocket.connect(InetSocketAddress("localhost", 12321))
205 221
 
206
-                GlobalScope.launch {
222
+                this@TlsSocketTest.launch {
207 223
                     with(socket.await().getOutputStream()) {
208 224
                         write("Hack the planet!".toByteArray())
209 225
                         flush()
@@ -221,12 +237,12 @@ internal class TlsSocketTest {
221 237
     }
222 238
 
223 239
     @Test
224
-    fun `read returns null after close`() = runBlocking {
240
+    fun `read returns null after close`() = runBlocking(coroutineContext) {
225 241
         withTimeout(5000) {
226 242
             tlsServerSocket(12321).use { serverSocket ->
227
-                val plainSocket = PlainTextSocket(GlobalScope)
228
-                val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
229
-                GlobalScope.launch {
243
+                val plainSocket = PlainTextSocket(this@TlsSocketTest)
244
+                val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "localhost")
245
+                this@TlsSocketTest.launch {
230 246
                     serverSocket.accept().getInputStream().read()
231 247
                 }
232 248
 
@@ -244,13 +260,13 @@ internal class TlsSocketTest {
244 260
     @Test
245 261
     fun `throws if the hostname mismatches`() {
246 262
         tlsServerSocket(12321).use { serverSocket ->
247
-            val plainSocket = PlainTextSocket(GlobalScope)
248
-            val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "127.0.0.1")
249
-            GlobalScope.launch {
263
+            val plainSocket = PlainTextSocket(this@TlsSocketTest)
264
+            val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "127.0.0.1")
265
+            launch {
250 266
                 serverSocket.accept().getInputStream().read()
251 267
             }
252 268
 
253
-            runBlocking {
269
+            runBlocking(coroutineContext) {
254 270
                 withTimeout(5000) {
255 271
                     try {
256 272
                         tlsSocket.connect(InetSocketAddress("localhost", 12321))
@@ -262,8 +278,6 @@ internal class TlsSocketTest {
262 278
             }
263 279
         }
264 280
     }
265
-
266
-
267 281
 }
268 282
 
269 283
 internal fun tlsServerSocket(port: Int): ServerSocket {

Loading…
Cancel
Save