Browse Source

Add timeout to TLS tests

tags/v1.1.0
Chris Smith 5 years ago
parent
commit
8f27a42b4e
1 changed files with 54 additions and 49 deletions
  1. 54
    49
      src/test/kotlin/com/dmdirc/ktirc/io/TlsTest.kt

+ 54
- 49
src/test/kotlin/com/dmdirc/ktirc/io/TlsTest.kt View File

@@ -2,11 +2,8 @@ package com.dmdirc.ktirc.io
2 2
 
3 3
 import io.mockk.every
4 4
 import io.mockk.mockk
5
-import kotlinx.coroutines.GlobalScope
6
-import kotlinx.coroutines.async
5
+import kotlinx.coroutines.*
7 6
 import kotlinx.coroutines.io.writeFully
8
-import kotlinx.coroutines.launch
9
-import kotlinx.coroutines.runBlocking
10 7
 import kotlinx.io.core.String
11 8
 import org.junit.jupiter.api.Assertions.*
12 9
 import org.junit.jupiter.api.Test
@@ -170,71 +167,77 @@ internal class TlsSocketTest {
170 167
 
171 168
     @Test
172 169
     fun `can send a string to a server over TLS`() = runBlocking {
173
-        tlsServerSocket(12321).use { serverSocket ->
174
-            val plainSocket = PlainTextSocket(GlobalScope)
175
-            val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
176
-            val clientBytesAsync = GlobalScope.async {
177
-                ByteArray(13).apply {
178
-                    serverSocket.accept().getInputStream().read(this)
170
+        withTimeout(5000) {
171
+            tlsServerSocket(12321).use { serverSocket ->
172
+                val plainSocket = PlainTextSocket(GlobalScope)
173
+                val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
174
+                val clientBytesAsync = GlobalScope.async {
175
+                    ByteArray(13).apply {
176
+                        serverSocket.accept().getInputStream().read(this)
177
+                    }
179 178
                 }
180
-            }
181 179
 
182
-            tlsSocket.connect(InetSocketAddress("localhost", 12321))
183
-            tlsSocket.write.writeFully("Hello World\r\n".toByteArray())
180
+                tlsSocket.connect(InetSocketAddress("localhost", 12321))
181
+                tlsSocket.write.writeFully("Hello World\r\n".toByteArray())
184 182
 
185
-            val bytes = clientBytesAsync.await()
186
-            assertNotNull(bytes)
187
-            assertEquals("Hello World\r\n", String(bytes))
183
+                val bytes = clientBytesAsync.await()
184
+                assertNotNull(bytes)
185
+                assertEquals("Hello World\r\n", String(bytes))
186
+            }
188 187
         }
189 188
     }
190 189
 
191 190
     @Test
192 191
     fun `can read a string from a server over TLS`() = runBlocking<Unit> {
193
-        tlsServerSocket(12321).use { serverSocket ->
194
-            val plainSocket = PlainTextSocket(GlobalScope)
195
-            val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
196
-            val socket = GlobalScope.async {
197
-                serverSocket.accept().apply {
198
-                    GlobalScope.launch {
199
-                        getInputStream().read()
192
+        withTimeout(5000) {
193
+            tlsServerSocket(12321).use { serverSocket ->
194
+                val plainSocket = PlainTextSocket(GlobalScope)
195
+                val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
196
+                val socket = GlobalScope.async {
197
+                    serverSocket.accept().apply {
198
+                        GlobalScope.launch {
199
+                            getInputStream().read()
200
+                        }
200 201
                     }
201 202
                 }
202
-            }
203 203
 
204
-            tlsSocket.connect(InetSocketAddress("localhost", 12321))
204
+                tlsSocket.connect(InetSocketAddress("localhost", 12321))
205 205
 
206
-            GlobalScope.launch {
207
-                with (socket.await().getOutputStream()) {
208
-                    write("Hack the planet!".toByteArray())
209
-                    flush()
206
+                GlobalScope.launch {
207
+                    with(socket.await().getOutputStream()) {
208
+                        write("Hack the planet!".toByteArray())
209
+                        flush()
210
+                    }
210 211
                 }
211
-            }
212 212
 
213
-            val buffer = tlsSocket.read()
213
+                val buffer = tlsSocket.read()
214 214
 
215
-            assertNotNull(buffer)
216
-            buffer?.let {
217
-                assertEquals("Hack the planet!", String(it.array(), 0, it.limit()))
215
+                assertNotNull(buffer)
216
+                buffer?.let {
217
+                    assertEquals("Hack the planet!", String(it.array(), 0, it.limit()))
218
+                }
218 219
             }
219 220
         }
220 221
     }
221 222
 
222 223
     @Test
223 224
     fun `read returns null after close`() = runBlocking {
224
-        tlsServerSocket(12321).use { serverSocket ->
225
-            val plainSocket = PlainTextSocket(GlobalScope)
226
-            val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
227
-            GlobalScope.launch {
228
-                serverSocket.accept().getInputStream().read()
229
-            }
225
+        withTimeout(5000) {
226
+            tlsServerSocket(12321).use { serverSocket ->
227
+                val plainSocket = PlainTextSocket(GlobalScope)
228
+                val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
229
+                GlobalScope.launch {
230
+                    serverSocket.accept().getInputStream().read()
231
+                }
230 232
 
231
-            tlsSocket.connect(InetSocketAddress("localhost", 12321))
233
+                tlsSocket.connect(InetSocketAddress("localhost", 12321))
232 234
 
233
-            tlsSocket.close()
235
+                tlsSocket.close()
234 236
 
235
-            val buffer = tlsSocket.read()
237
+                val buffer = tlsSocket.read()
236 238
 
237
-            assertNull(buffer)
239
+                assertNull(buffer)
240
+            }
238 241
         }
239 242
     }
240 243
 
@@ -248,11 +251,13 @@ internal class TlsSocketTest {
248 251
             }
249 252
 
250 253
             runBlocking {
251
-                try {
252
-                    tlsSocket.connect(InetSocketAddress("localhost", 12321))
253
-                    fail<Unit>("Expected an exception")
254
-                } catch (ex: Exception) {
255
-                    assertTrue(ex is CertificateException)
254
+                withTimeout(5000) {
255
+                    try {
256
+                        tlsSocket.connect(InetSocketAddress("localhost", 12321))
257
+                        fail<Unit>("Expected an exception")
258
+                    } catch (ex: Exception) {
259
+                        assertTrue(ex is CertificateException)
260
+                    }
256 261
                 }
257 262
             }
258 263
         }

Loading…
Cancel
Save