Browse Source

Create a separate coroutine context per test

tags/v1.1.0
Chris Smith 5 years ago
parent
commit
136329c27d
1 changed files with 35 additions and 21 deletions
  1. 35
    21
      src/test/kotlin/com/dmdirc/ktirc/io/TlsTest.kt

+ 35
- 21
src/test/kotlin/com/dmdirc/ktirc/io/TlsTest.kt View File

5
 import kotlinx.coroutines.*
5
 import kotlinx.coroutines.*
6
 import kotlinx.coroutines.io.writeFully
6
 import kotlinx.coroutines.io.writeFully
7
 import kotlinx.io.core.String
7
 import kotlinx.io.core.String
8
+import org.junit.jupiter.api.AfterEach
8
 import org.junit.jupiter.api.Assertions.*
9
 import org.junit.jupiter.api.Assertions.*
10
+import org.junit.jupiter.api.BeforeEach
9
 import org.junit.jupiter.api.Test
11
 import org.junit.jupiter.api.Test
10
 import org.junit.jupiter.api.parallel.Execution
12
 import org.junit.jupiter.api.parallel.Execution
11
 import org.junit.jupiter.api.parallel.ExecutionMode
13
 import org.junit.jupiter.api.parallel.ExecutionMode
17
 import javax.net.ssl.KeyManagerFactory
19
 import javax.net.ssl.KeyManagerFactory
18
 import javax.net.ssl.SSLContext
20
 import javax.net.ssl.SSLContext
19
 import javax.net.ssl.X509TrustManager
21
 import javax.net.ssl.X509TrustManager
22
+import kotlin.coroutines.CoroutineContext
20
 
23
 
21
 internal class CertificateValidationTest {
24
 internal class CertificateValidationTest {
22
 
25
 
163
 
166
 
164
 @Suppress("BlockingMethodInNonBlockingContext")
167
 @Suppress("BlockingMethodInNonBlockingContext")
165
 @Execution(ExecutionMode.SAME_THREAD)
168
 @Execution(ExecutionMode.SAME_THREAD)
166
-internal class TlsSocketTest {
169
+internal class TlsSocketTest: CoroutineScope {
170
+
171
+    override var coroutineContext: CoroutineContext = GlobalScope.coroutineContext
172
+
173
+    @ObsoleteCoroutinesApi
174
+    @BeforeEach
175
+    fun setup() {
176
+        coroutineContext = newFixedThreadPoolContext(4, "tls-test")
177
+    }
178
+
179
+    @AfterEach
180
+    fun teardown() {
181
+        coroutineContext.cancel()
182
+    }
167
 
183
 
168
     @Test
184
     @Test
169
-    fun `can send a string to a server over TLS`() = runBlocking {
185
+    fun `can send a string to a server over TLS`() = runBlocking(coroutineContext) {
170
         withTimeout(5000) {
186
         withTimeout(5000) {
171
             tlsServerSocket(12321).use { serverSocket ->
187
             tlsServerSocket(12321).use { serverSocket ->
172
-                val plainSocket = PlainTextSocket(GlobalScope)
173
-                val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
174
-                val clientBytesAsync = GlobalScope.async {
188
+                val plainSocket = PlainTextSocket(this@TlsSocketTest)
189
+                val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "localhost")
190
+                val clientBytesAsync = this@TlsSocketTest.async {
175
                     ByteArray(13).apply {
191
                     ByteArray(13).apply {
176
                         serverSocket.accept().getInputStream().read(this)
192
                         serverSocket.accept().getInputStream().read(this)
177
                     }
193
                     }
188
     }
204
     }
189
 
205
 
190
     @Test
206
     @Test
191
-    fun `can read a string from a server over TLS`() = runBlocking<Unit> {
207
+    fun `can read a string from a server over TLS`() = runBlocking<Unit>(coroutineContext) {
192
         withTimeout(5000) {
208
         withTimeout(5000) {
193
             tlsServerSocket(12321).use { serverSocket ->
209
             tlsServerSocket(12321).use { serverSocket ->
194
-                val plainSocket = PlainTextSocket(GlobalScope)
195
-                val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
196
-                val socket = GlobalScope.async {
210
+                val plainSocket = PlainTextSocket(this@TlsSocketTest)
211
+                val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "localhost")
212
+                val socket = this@TlsSocketTest.async {
197
                     serverSocket.accept().apply {
213
                     serverSocket.accept().apply {
198
-                        GlobalScope.launch {
214
+                        this@TlsSocketTest.launch {
199
                             getInputStream().read()
215
                             getInputStream().read()
200
                         }
216
                         }
201
                     }
217
                     }
203
 
219
 
204
                 tlsSocket.connect(InetSocketAddress("localhost", 12321))
220
                 tlsSocket.connect(InetSocketAddress("localhost", 12321))
205
 
221
 
206
-                GlobalScope.launch {
222
+                this@TlsSocketTest.launch {
207
                     with(socket.await().getOutputStream()) {
223
                     with(socket.await().getOutputStream()) {
208
                         write("Hack the planet!".toByteArray())
224
                         write("Hack the planet!".toByteArray())
209
                         flush()
225
                         flush()
221
     }
237
     }
222
 
238
 
223
     @Test
239
     @Test
224
-    fun `read returns null after close`() = runBlocking {
240
+    fun `read returns null after close`() = runBlocking(coroutineContext) {
225
         withTimeout(5000) {
241
         withTimeout(5000) {
226
             tlsServerSocket(12321).use { serverSocket ->
242
             tlsServerSocket(12321).use { serverSocket ->
227
-                val plainSocket = PlainTextSocket(GlobalScope)
228
-                val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
229
-                GlobalScope.launch {
243
+                val plainSocket = PlainTextSocket(this@TlsSocketTest)
244
+                val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "localhost")
245
+                this@TlsSocketTest.launch {
230
                     serverSocket.accept().getInputStream().read()
246
                     serverSocket.accept().getInputStream().read()
231
                 }
247
                 }
232
 
248
 
244
     @Test
260
     @Test
245
     fun `throws if the hostname mismatches`() {
261
     fun `throws if the hostname mismatches`() {
246
         tlsServerSocket(12321).use { serverSocket ->
262
         tlsServerSocket(12321).use { serverSocket ->
247
-            val plainSocket = PlainTextSocket(GlobalScope)
248
-            val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "127.0.0.1")
249
-            GlobalScope.launch {
263
+            val plainSocket = PlainTextSocket(this@TlsSocketTest)
264
+            val tlsSocket = TlsSocket(this@TlsSocketTest, plainSocket, getTrustingContext(), "127.0.0.1")
265
+            launch {
250
                 serverSocket.accept().getInputStream().read()
266
                 serverSocket.accept().getInputStream().read()
251
             }
267
             }
252
 
268
 
253
-            runBlocking {
269
+            runBlocking(coroutineContext) {
254
                 withTimeout(5000) {
270
                 withTimeout(5000) {
255
                     try {
271
                     try {
256
                         tlsSocket.connect(InetSocketAddress("localhost", 12321))
272
                         tlsSocket.connect(InetSocketAddress("localhost", 12321))
262
             }
278
             }
263
         }
279
         }
264
     }
280
     }
265
-
266
-
267
 }
281
 }
268
 
282
 
269
 internal fun tlsServerSocket(port: Int): ServerSocket {
283
 internal fun tlsServerSocket(port: Int): ServerSocket {

Loading…
Cancel
Save