Browse Source

Add timeout to TLS tests

tags/v1.1.0
Chris Smith 5 years ago
parent
commit
8f27a42b4e
1 changed files with 54 additions and 49 deletions
  1. 54
    49
      src/test/kotlin/com/dmdirc/ktirc/io/TlsTest.kt

+ 54
- 49
src/test/kotlin/com/dmdirc/ktirc/io/TlsTest.kt View File

2
 
2
 
3
 import io.mockk.every
3
 import io.mockk.every
4
 import io.mockk.mockk
4
 import io.mockk.mockk
5
-import kotlinx.coroutines.GlobalScope
6
-import kotlinx.coroutines.async
5
+import kotlinx.coroutines.*
7
 import kotlinx.coroutines.io.writeFully
6
 import kotlinx.coroutines.io.writeFully
8
-import kotlinx.coroutines.launch
9
-import kotlinx.coroutines.runBlocking
10
 import kotlinx.io.core.String
7
 import kotlinx.io.core.String
11
 import org.junit.jupiter.api.Assertions.*
8
 import org.junit.jupiter.api.Assertions.*
12
 import org.junit.jupiter.api.Test
9
 import org.junit.jupiter.api.Test
170
 
167
 
171
     @Test
168
     @Test
172
     fun `can send a string to a server over TLS`() = runBlocking {
169
     fun `can send a string to a server over TLS`() = runBlocking {
173
-        tlsServerSocket(12321).use { serverSocket ->
174
-            val plainSocket = PlainTextSocket(GlobalScope)
175
-            val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
176
-            val clientBytesAsync = GlobalScope.async {
177
-                ByteArray(13).apply {
178
-                    serverSocket.accept().getInputStream().read(this)
170
+        withTimeout(5000) {
171
+            tlsServerSocket(12321).use { serverSocket ->
172
+                val plainSocket = PlainTextSocket(GlobalScope)
173
+                val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
174
+                val clientBytesAsync = GlobalScope.async {
175
+                    ByteArray(13).apply {
176
+                        serverSocket.accept().getInputStream().read(this)
177
+                    }
179
                 }
178
                 }
180
-            }
181
 
179
 
182
-            tlsSocket.connect(InetSocketAddress("localhost", 12321))
183
-            tlsSocket.write.writeFully("Hello World\r\n".toByteArray())
180
+                tlsSocket.connect(InetSocketAddress("localhost", 12321))
181
+                tlsSocket.write.writeFully("Hello World\r\n".toByteArray())
184
 
182
 
185
-            val bytes = clientBytesAsync.await()
186
-            assertNotNull(bytes)
187
-            assertEquals("Hello World\r\n", String(bytes))
183
+                val bytes = clientBytesAsync.await()
184
+                assertNotNull(bytes)
185
+                assertEquals("Hello World\r\n", String(bytes))
186
+            }
188
         }
187
         }
189
     }
188
     }
190
 
189
 
191
     @Test
190
     @Test
192
     fun `can read a string from a server over TLS`() = runBlocking<Unit> {
191
     fun `can read a string from a server over TLS`() = runBlocking<Unit> {
193
-        tlsServerSocket(12321).use { serverSocket ->
194
-            val plainSocket = PlainTextSocket(GlobalScope)
195
-            val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
196
-            val socket = GlobalScope.async {
197
-                serverSocket.accept().apply {
198
-                    GlobalScope.launch {
199
-                        getInputStream().read()
192
+        withTimeout(5000) {
193
+            tlsServerSocket(12321).use { serverSocket ->
194
+                val plainSocket = PlainTextSocket(GlobalScope)
195
+                val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
196
+                val socket = GlobalScope.async {
197
+                    serverSocket.accept().apply {
198
+                        GlobalScope.launch {
199
+                            getInputStream().read()
200
+                        }
200
                     }
201
                     }
201
                 }
202
                 }
202
-            }
203
 
203
 
204
-            tlsSocket.connect(InetSocketAddress("localhost", 12321))
204
+                tlsSocket.connect(InetSocketAddress("localhost", 12321))
205
 
205
 
206
-            GlobalScope.launch {
207
-                with (socket.await().getOutputStream()) {
208
-                    write("Hack the planet!".toByteArray())
209
-                    flush()
206
+                GlobalScope.launch {
207
+                    with(socket.await().getOutputStream()) {
208
+                        write("Hack the planet!".toByteArray())
209
+                        flush()
210
+                    }
210
                 }
211
                 }
211
-            }
212
 
212
 
213
-            val buffer = tlsSocket.read()
213
+                val buffer = tlsSocket.read()
214
 
214
 
215
-            assertNotNull(buffer)
216
-            buffer?.let {
217
-                assertEquals("Hack the planet!", String(it.array(), 0, it.limit()))
215
+                assertNotNull(buffer)
216
+                buffer?.let {
217
+                    assertEquals("Hack the planet!", String(it.array(), 0, it.limit()))
218
+                }
218
             }
219
             }
219
         }
220
         }
220
     }
221
     }
221
 
222
 
222
     @Test
223
     @Test
223
     fun `read returns null after close`() = runBlocking {
224
     fun `read returns null after close`() = runBlocking {
224
-        tlsServerSocket(12321).use { serverSocket ->
225
-            val plainSocket = PlainTextSocket(GlobalScope)
226
-            val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
227
-            GlobalScope.launch {
228
-                serverSocket.accept().getInputStream().read()
229
-            }
225
+        withTimeout(5000) {
226
+            tlsServerSocket(12321).use { serverSocket ->
227
+                val plainSocket = PlainTextSocket(GlobalScope)
228
+                val tlsSocket = TlsSocket(GlobalScope, plainSocket, getTrustingContext(), "localhost")
229
+                GlobalScope.launch {
230
+                    serverSocket.accept().getInputStream().read()
231
+                }
230
 
232
 
231
-            tlsSocket.connect(InetSocketAddress("localhost", 12321))
233
+                tlsSocket.connect(InetSocketAddress("localhost", 12321))
232
 
234
 
233
-            tlsSocket.close()
235
+                tlsSocket.close()
234
 
236
 
235
-            val buffer = tlsSocket.read()
237
+                val buffer = tlsSocket.read()
236
 
238
 
237
-            assertNull(buffer)
239
+                assertNull(buffer)
240
+            }
238
         }
241
         }
239
     }
242
     }
240
 
243
 
248
             }
251
             }
249
 
252
 
250
             runBlocking {
253
             runBlocking {
251
-                try {
252
-                    tlsSocket.connect(InetSocketAddress("localhost", 12321))
253
-                    fail<Unit>("Expected an exception")
254
-                } catch (ex: Exception) {
255
-                    assertTrue(ex is CertificateException)
254
+                withTimeout(5000) {
255
+                    try {
256
+                        tlsSocket.connect(InetSocketAddress("localhost", 12321))
257
+                        fail<Unit>("Expected an exception")
258
+                    } catch (ex: Exception) {
259
+                        assertTrue(ex is CertificateException)
260
+                    }
256
                 }
261
                 }
257
             }
262
             }
258
         }
263
         }

Loading…
Cancel
Save