Browse Source

Improve couroutines and channel use in the socket

tags/v0.6.0
Chris Smith 5 years ago
parent
commit
e64b705412

+ 1
- 0
CHANGELOG View File

11
     * Added sendTagMessage() to send message tags without any content
11
     * Added sendTagMessage() to send message tags without any content
12
     * The reply() utility automatically marks messages as a reply
12
     * The reply() utility automatically marks messages as a reply
13
     * Added react() utility to send a reaction client tag
13
     * Added react() utility to send a reaction client tag
14
+ * (Internal) improved how coroutines and channels are used in LineBufferedSocket
14
 
15
 
15
 v0.5.0
16
 v0.5.0
16
 
17
 

+ 0
- 3
src/itest/kotlin/com/dmdirc/ktirc/KtIrcIntegrationTest.kt View File

3
 import com.dmdirc.irctest.IrcLibraryTests
3
 import com.dmdirc.irctest.IrcLibraryTests
4
 import com.dmdirc.ktirc.model.Profile
4
 import com.dmdirc.ktirc.model.Profile
5
 import com.dmdirc.ktirc.model.Server
5
 import com.dmdirc.ktirc.model.Server
6
-import kotlinx.coroutines.delay
7
 import kotlinx.coroutines.runBlocking
6
 import kotlinx.coroutines.runBlocking
8
 import org.junit.jupiter.api.TestFactory
7
 import org.junit.jupiter.api.TestFactory
9
 
8
 
21
 
20
 
22
         override fun terminate() {
21
         override fun terminate() {
23
             runBlocking {
22
             runBlocking {
24
-                delay(100)
25
                 ircClient.disconnect()
23
                 ircClient.disconnect()
26
-                ircClient.join()
27
             }
24
             }
28
         }
25
         }
29
 
26
 

+ 15
- 33
src/main/kotlin/com/dmdirc/ktirc/IrcClient.kt View File

5
 import com.dmdirc.ktirc.messages.*
5
 import com.dmdirc.ktirc.messages.*
6
 import com.dmdirc.ktirc.model.*
6
 import com.dmdirc.ktirc.model.*
7
 import com.dmdirc.ktirc.util.currentTimeProvider
7
 import com.dmdirc.ktirc.util.currentTimeProvider
8
+import io.ktor.util.KtorExperimentalAPI
8
 import kotlinx.coroutines.*
9
 import kotlinx.coroutines.*
9
-import kotlinx.coroutines.channels.Channel
10
 import kotlinx.coroutines.channels.map
10
 import kotlinx.coroutines.channels.map
11
 import java.util.concurrent.atomic.AtomicBoolean
11
 import java.util.concurrent.atomic.AtomicBoolean
12
-import java.util.logging.Level
13
-import java.util.logging.LogManager
14
 
12
 
15
 /**
13
 /**
16
  * Primary interface for interacting with KtIrc.
14
  * Primary interface for interacting with KtIrc.
88
 // TODO: How should alternative nicknames work?
86
 // TODO: How should alternative nicknames work?
89
 // TODO: Should IRC Client take a pool of servers and rotate through, or make the caller do that?
87
 // TODO: Should IRC Client take a pool of servers and rotate through, or make the caller do that?
90
 // TODO: Should there be a default profile?
88
 // TODO: Should there be a default profile?
91
-class IrcClientImpl(private val server: Server, override val profile: Profile) : IrcClient {
89
+@KtorExperimentalAPI
90
+@ExperimentalCoroutinesApi
91
+class IrcClientImpl(private val server: Server, override val profile: Profile) : IrcClient, CoroutineScope {
92
 
92
 
93
-    internal var socketFactory: (String, Int, Boolean) -> LineBufferedSocket = ::KtorLineBufferedSocket
93
+    override val coroutineContext = GlobalScope.newCoroutineContext(Dispatchers.IO)
94
+
95
+    internal var socketFactory: (CoroutineScope, String, Int, Boolean) -> LineBufferedSocket = ::KtorLineBufferedSocket
94
 
96
 
95
     override val serverState = ServerState(profile.initialNick, server.host)
97
     override val serverState = ServerState(profile.initialNick, server.host)
96
     override val channelState = ChannelStateMap { caseMapping }
98
     override val channelState = ChannelStateMap { caseMapping }
101
     private val parser = MessageParser()
103
     private val parser = MessageParser()
102
     private var socket: LineBufferedSocket? = null
104
     private var socket: LineBufferedSocket? = null
103
 
105
 
104
-    private val scope = CoroutineScope(Dispatchers.IO)
105
     private val connecting = AtomicBoolean(false)
106
     private val connecting = AtomicBoolean(false)
106
 
107
 
107
-    private var connectionJob: Job? = null
108
-    internal var writeChannel: Channel<ByteArray>? = null
109
-
110
     override fun send(message: String) {
108
     override fun send(message: String) {
111
-        writeChannel?.offer(message.toByteArray())
109
+        socket?.sendChannel?.offer(message.toByteArray())
112
     }
110
     }
113
 
111
 
114
     override fun connect() {
112
     override fun connect() {
115
         check(!connecting.getAndSet(true))
113
         check(!connecting.getAndSet(true))
116
-        connectionJob = scope.launch {
117
-            with(socketFactory(server.host, server.port, server.tls)) {
118
-                // TODO: Proper error handling - what if connect() fails?
119
-                socket = this
120
 
114
 
121
-                emitEvent(ServerConnecting(currentTimeProvider()))
115
+        with(socketFactory(this, server.host, server.port, server.tls)) {
116
+            // TODO: Proper error handling - what if connect() fails?
117
+            socket = this
122
 
118
 
123
-                connect()
124
-
125
-                with(Channel<ByteArray>(Channel.UNLIMITED)) {
126
-                    writeChannel = this
127
-                    scope.launch {
128
-                        writeChannel?.let {
129
-                            writeLines(it)
130
-                        }
131
-                    }
132
-                }
119
+            emitEvent(ServerConnecting(currentTimeProvider()))
133
 
120
 
121
+            launch {
122
+                connect()
134
                 emitEvent(ServerConnected(currentTimeProvider()))
123
                 emitEvent(ServerConnected(currentTimeProvider()))
135
                 sendCapabilityList()
124
                 sendCapabilityList()
136
                 sendPasswordIfPresent()
125
                 sendPasswordIfPresent()
137
                 sendNickChange(profile.initialNick)
126
                 sendNickChange(profile.initialNick)
138
                 // TODO: Send correct host
127
                 // TODO: Send correct host
139
                 sendUser(profile.userName, profile.realName)
128
                 sendUser(profile.userName, profile.realName)
140
-                messageHandler.processMessages(this@IrcClientImpl, readLines(scope).map { parser.parse(it) })
129
+                messageHandler.processMessages(this@IrcClientImpl, receiveChannel.map { parser.parse(it) })
141
                 emitEvent(ServerDisconnected(currentTimeProvider()))
130
                 emitEvent(ServerDisconnected(currentTimeProvider()))
142
             }
131
             }
143
         }
132
         }
147
         socket?.disconnect()
136
         socket?.disconnect()
148
     }
137
     }
149
 
138
 
150
-    /**
151
-     * Joins the coroutine running the message loop, and blocks until it is completed.
152
-     */
153
-    suspend fun join() {
154
-        connectionJob?.join()
155
-    }
156
-
157
     override fun onEvent(handler: (IrcEvent) -> Unit) {
139
     override fun onEvent(handler: (IrcEvent) -> Unit) {
158
         messageHandler.handlers.add(object : EventHandler {
140
         messageHandler.handlers.add(object : EventHandler {
159
             override fun processEvent(client: IrcClient, event: IrcEvent): List<IrcEvent> {
141
             override fun processEvent(client: IrcClient, event: IrcEvent): List<IrcEvent> {

+ 46
- 36
src/main/kotlin/com/dmdirc/ktirc/io/LineBufferedSocket.kt View File

7
 import io.ktor.network.sockets.openReadChannel
7
 import io.ktor.network.sockets.openReadChannel
8
 import io.ktor.network.sockets.openWriteChannel
8
 import io.ktor.network.sockets.openWriteChannel
9
 import io.ktor.network.tls.tls
9
 import io.ktor.network.tls.tls
10
-import kotlinx.coroutines.CoroutineScope
11
-import kotlinx.coroutines.Dispatchers
12
-import kotlinx.coroutines.ExperimentalCoroutinesApi
13
-import kotlinx.coroutines.GlobalScope
10
+import io.ktor.util.KtorExperimentalAPI
11
+import kotlinx.coroutines.*
12
+import kotlinx.coroutines.channels.Channel
14
 import kotlinx.coroutines.channels.ReceiveChannel
13
 import kotlinx.coroutines.channels.ReceiveChannel
14
+import kotlinx.coroutines.channels.SendChannel
15
 import kotlinx.coroutines.channels.produce
15
 import kotlinx.coroutines.channels.produce
16
 import kotlinx.coroutines.io.ByteReadChannel
16
 import kotlinx.coroutines.io.ByteReadChannel
17
 import kotlinx.coroutines.io.ByteWriteChannel
17
 import kotlinx.coroutines.io.ByteWriteChannel
21
 
21
 
22
 internal interface LineBufferedSocket {
22
 internal interface LineBufferedSocket {
23
 
23
 
24
-    suspend fun connect()
24
+    fun connect()
25
     fun disconnect()
25
     fun disconnect()
26
 
26
 
27
-    fun readLines(coroutineScope: CoroutineScope): ReceiveChannel<ByteArray>
28
-    suspend fun writeLines(channel: ReceiveChannel<ByteArray>)
27
+    val sendChannel: SendChannel<ByteArray>
28
+    val receiveChannel: ReceiveChannel<ByteArray>
29
 
29
 
30
 }
30
 }
31
 
31
 
33
  * Asynchronous socket that buffers incoming data and emits individual lines.
33
  * Asynchronous socket that buffers incoming data and emits individual lines.
34
  */
34
  */
35
 // TODO: Expose advanced TLS options
35
 // TODO: Expose advanced TLS options
36
-internal class KtorLineBufferedSocket(private val host: String, private val port: Int, private val tls: Boolean = false): LineBufferedSocket {
36
+@KtorExperimentalAPI
37
+@ExperimentalCoroutinesApi
38
+internal class KtorLineBufferedSocket(coroutineScope: CoroutineScope, private val host: String, private val port: Int, private val tls: Boolean = false) : CoroutineScope, LineBufferedSocket {
37
 
39
 
38
     companion object {
40
     companion object {
39
         const val CARRIAGE_RETURN = '\r'.toByte()
41
         const val CARRIAGE_RETURN = '\r'.toByte()
40
         const val LINE_FEED = '\n'.toByte()
42
         const val LINE_FEED = '\n'.toByte()
41
     }
43
     }
42
 
44
 
45
+    override val coroutineContext = coroutineScope.newCoroutineContext(Dispatchers.IO)
46
+    override val sendChannel: Channel<ByteArray> = Channel(Channel.UNLIMITED)
47
+
43
     var tlsTrustManager: X509TrustManager? = null
48
     var tlsTrustManager: X509TrustManager? = null
44
 
49
 
45
     private val log by logger()
50
     private val log by logger()
48
     private lateinit var readChannel: ByteReadChannel
53
     private lateinit var readChannel: ByteReadChannel
49
     private lateinit var writeChannel: ByteWriteChannel
54
     private lateinit var writeChannel: ByteWriteChannel
50
 
55
 
51
-    @Suppress("EXPERIMENTAL_API_USAGE")
52
-    override suspend fun connect() {
53
-        log.info { "Connecting..." }
54
-        socket = aSocket(ActorSelectorManager(Dispatchers.IO)).tcp().connect(InetSocketAddress(host, port))
55
-        if (tls) {
56
-            // TODO: Figure out how exactly scopes work...
57
-            socket = socket.tls(GlobalScope.coroutineContext, randomAlgorithm = SecureRandom.getInstanceStrong().algorithm, trustManager = tlsTrustManager)
56
+    override fun connect() {
57
+        runBlocking {
58
+            log.info { "Connecting..." }
59
+            socket = aSocket(ActorSelectorManager(Dispatchers.IO)).tcp().connect(InetSocketAddress(host, port))
60
+            if (tls) {
61
+                socket = socket.tls(
62
+                        coroutineContext = this@KtorLineBufferedSocket.coroutineContext,
63
+                        randomAlgorithm = SecureRandom.getInstanceStrong().algorithm,
64
+                        trustManager = tlsTrustManager)
65
+            }
66
+            readChannel = socket.openReadChannel()
67
+            writeChannel = socket.openWriteChannel()
58
         }
68
         }
59
-        readChannel = socket.openReadChannel()
60
-        writeChannel = socket.openWriteChannel()
69
+        launch { writeLines() }
61
     }
70
     }
62
 
71
 
63
     override fun disconnect() {
72
     override fun disconnect() {
64
         log.info { "Disconnecting..." }
73
         log.info { "Disconnecting..." }
65
         socket.close()
74
         socket.close()
75
+        coroutineContext.cancel()
66
     }
76
     }
67
 
77
 
68
-    @ExperimentalCoroutinesApi
69
-    override fun readLines(coroutineScope: CoroutineScope) = coroutineScope.produce {
70
-        val lineBuffer = ByteArray(4096)
71
-        var index = 0
72
-        while (!readChannel.isClosedForRead) {
73
-            var start = index
74
-            val count = readChannel.readAvailable(lineBuffer, index, lineBuffer.size - index)
75
-            for (i in index until index + count) {
76
-                if (lineBuffer[i] == CARRIAGE_RETURN || lineBuffer[i] == LINE_FEED) {
77
-                    if (start < i) {
78
-                        val line = lineBuffer.sliceArray(start until i)
79
-                        log.fine { "<<< ${String(line)}" }
80
-                        send(line)
78
+    override val receiveChannel
79
+        get() = produce {
80
+            val lineBuffer = ByteArray(4096)
81
+            var index = 0
82
+            while (!readChannel.isClosedForRead) {
83
+                var start = index
84
+                val count = readChannel.readAvailable(lineBuffer, index, lineBuffer.size - index)
85
+                for (i in index until index + count) {
86
+                    if (lineBuffer[i] == CARRIAGE_RETURN || lineBuffer[i] == LINE_FEED) {
87
+                        if (start < i) {
88
+                            val line = lineBuffer.sliceArray(start until i)
89
+                            log.fine { "<<< ${String(line)}" }
90
+                            send(line)
91
+                        }
92
+                        start = i + 1
81
                     }
93
                     }
82
-                    start = i + 1
83
                 }
94
                 }
95
+                lineBuffer.copyInto(lineBuffer, 0, start)
96
+                index = count + index - start
84
             }
97
             }
85
-            lineBuffer.copyInto(lineBuffer, 0, start)
86
-            index = count + index - start
87
         }
98
         }
88
-    }
89
 
99
 
90
-    override suspend fun writeLines(channel: ReceiveChannel<ByteArray>) {
91
-        for (line in channel) {
100
+    private suspend fun writeLines() {
101
+        for (line in sendChannel) {
92
             with(writeChannel) {
102
             with(writeChannel) {
93
                 log.fine { ">>> ${String(line)}" }
103
                 log.fine { ">>> ${String(line)}" }
94
                 writeAvailable(line, 0, line.size)
104
                 writeAvailable(line, 0, line.size)

+ 25
- 61
src/test/kotlin/com/dmdirc/ktirc/IrcClientTest.kt View File

3
 import com.dmdirc.ktirc.events.*
3
 import com.dmdirc.ktirc.events.*
4
 import com.dmdirc.ktirc.io.CaseMapping
4
 import com.dmdirc.ktirc.io.CaseMapping
5
 import com.dmdirc.ktirc.io.LineBufferedSocket
5
 import com.dmdirc.ktirc.io.LineBufferedSocket
6
-import com.dmdirc.ktirc.model.Profile
7
-import com.dmdirc.ktirc.model.Server
8
-import com.dmdirc.ktirc.model.ServerFeature
9
-import com.dmdirc.ktirc.model.User
6
+import com.dmdirc.ktirc.model.*
10
 import com.dmdirc.ktirc.util.currentTimeProvider
7
 import com.dmdirc.ktirc.util.currentTimeProvider
11
 import com.nhaarman.mockitokotlin2.*
8
 import com.nhaarman.mockitokotlin2.*
9
+import io.ktor.util.KtorExperimentalAPI
12
 import kotlinx.coroutines.*
10
 import kotlinx.coroutines.*
13
 import kotlinx.coroutines.channels.Channel
11
 import kotlinx.coroutines.channels.Channel
12
+import kotlinx.coroutines.channels.filter
13
+import kotlinx.coroutines.channels.map
14
 import org.junit.jupiter.api.Assertions.*
14
 import org.junit.jupiter.api.Assertions.*
15
 import org.junit.jupiter.api.BeforeEach
15
 import org.junit.jupiter.api.BeforeEach
16
 import org.junit.jupiter.api.Test
16
 import org.junit.jupiter.api.Test
17
 import org.junit.jupiter.api.assertThrows
17
 import org.junit.jupiter.api.assertThrows
18
 
18
 
19
+@KtorExperimentalAPI
20
+@ExperimentalCoroutinesApi
19
 internal class IrcClientImplTest {
21
 internal class IrcClientImplTest {
20
 
22
 
21
     companion object {
23
     companion object {
27
         private const val PASSWORD = "HackThePlanet"
29
         private const val PASSWORD = "HackThePlanet"
28
     }
30
     }
29
 
31
 
30
-    private val readLineChannel = Channel<ByteArray>(10)
32
+    private val readLineChannel = Channel<ByteArray>(Channel.UNLIMITED)
33
+    private val sendLineChannel = Channel<ByteArray>(Channel.UNLIMITED)
31
 
34
 
32
     private val mockSocket = mock<LineBufferedSocket> {
35
     private val mockSocket = mock<LineBufferedSocket> {
33
-        on { readLines(any()) } doReturn readLineChannel
36
+        on { receiveChannel } doReturn readLineChannel
37
+        on { sendChannel } doReturn sendLineChannel
34
     }
38
     }
35
 
39
 
36
-    private val mockSocketFactory = mock<(String, Int, Boolean) -> LineBufferedSocket> {
37
-        on { invoke(eq(HOST), eq(PORT), any()) } doReturn mockSocket
40
+    private val mockSocketFactory = mock<(CoroutineScope, String, Int, Boolean) -> LineBufferedSocket> {
41
+        on { invoke(any(), eq(HOST), eq(PORT), any()) } doReturn mockSocket
38
     }
42
     }
39
 
43
 
40
     private val mockEventHandler = mock<(IrcEvent) -> Unit>()
44
     private val mockEventHandler = mock<(IrcEvent) -> Unit>()
50
         client.socketFactory = mockSocketFactory
54
         client.socketFactory = mockSocketFactory
51
         client.connect()
55
         client.connect()
52
 
56
 
53
-        verify(mockSocketFactory, timeout(500)).invoke(HOST, PORT, false)
57
+        verify(mockSocketFactory, timeout(500)).invoke(client, HOST, PORT, false)
54
     }
58
     }
55
 
59
 
56
     @Test
60
     @Test
59
         client.socketFactory = mockSocketFactory
63
         client.socketFactory = mockSocketFactory
60
         client.connect()
64
         client.connect()
61
 
65
 
62
-        verify(mockSocketFactory, timeout(500)).invoke(HOST, PORT, true)
66
+        verify(mockSocketFactory, timeout(500)).invoke(client, HOST, PORT, true)
63
     }
67
     }
64
 
68
 
65
     @Test
69
     @Test
115
 
119
 
116
         client.blockUntilConnected()
120
         client.blockUntilConnected()
117
 
121
 
118
-        assertEquals("CAP LS 302", String(client.writeChannel!!.receive()))
119
-        assertEquals("NICK :$NICK", String(client.writeChannel!!.receive()))
120
-        assertEquals("USER $USER_NAME 0 * :$REAL_NAME", String(client.writeChannel!!.receive()))
122
+        assertEquals("CAP LS 302", String(sendLineChannel.receive()))
123
+        assertEquals("NICK :$NICK", String(sendLineChannel.receive()))
124
+        assertEquals("USER $USER_NAME 0 * :$REAL_NAME", String(sendLineChannel.receive()))
121
     }
125
     }
122
 
126
 
123
     @Test
127
     @Test
128
 
132
 
129
         client.blockUntilConnected()
133
         client.blockUntilConnected()
130
 
134
 
131
-        assertEquals("CAP LS 302", String(client.writeChannel!!.receive()))
132
-        assertEquals("PASS :$PASSWORD", String(client.writeChannel!!.receive()))
135
+        assertEquals("CAP LS 302", String(sendLineChannel.receive()))
136
+        assertEquals("PASS :$PASSWORD", String(sendLineChannel.receive()))
133
     }
137
     }
134
 
138
 
135
     @Test
139
     @Test
180
         assertFalse(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
184
         assertFalse(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
181
     }
185
     }
182
 
186
 
183
-    @Test
184
-    fun `IrcClientImpl join blocks when socket is open`() {
185
-        val client = IrcClientImpl(Server(HOST, PORT, password = PASSWORD), Profile(NICK, REAL_NAME, USER_NAME))
186
-        client.socketFactory = mockSocketFactory
187
-
188
-        GlobalScope.launch {
189
-            readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
190
-        }
191
-
192
-        client.connect()
193
-        runBlocking {
194
-            assertNull(withTimeoutOrNull(100L) {
195
-                client.join()
196
-                true
197
-            })
198
-        }
199
-    }
200
-
201
-    @Test
202
-    fun `IrcClientImpl join returns when socket is closed`() {
203
-        val client = IrcClientImpl(Server(HOST, PORT, password = PASSWORD), Profile(NICK, REAL_NAME, USER_NAME))
204
-        client.socketFactory = mockSocketFactory
205
-
206
-        GlobalScope.launch {
207
-            readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
208
-            readLineChannel.close()
209
-        }
210
-
211
-        client.connect()
212
-        runBlocking {
213
-            assertEquals(true, withTimeoutOrNull(500L) {
214
-                client.join()
215
-                true
216
-            })
217
-        }
218
-    }
219
-
220
     @Test
187
     @Test
221
     fun `IrcClientImpl sends text to socket`() = runBlocking {
188
     fun `IrcClientImpl sends text to socket`() = runBlocking {
222
         val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
189
         val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
229
 
196
 
230
         assertEquals(true, withTimeoutOrNull(500) {
197
         assertEquals(true, withTimeoutOrNull(500) {
231
             var found = false
198
             var found = false
232
-            for (line in client.writeChannel!!) {
199
+            for (line in sendLineChannel) {
233
                 if (String(line) == "testing 123") {
200
                 if (String(line) == "testing 123") {
234
                     found = true
201
                     found = true
235
                     break
202
                     break
264
 
231
 
265
         assertEquals(100, withTimeoutOrNull(500) {
232
         assertEquals(100, withTimeoutOrNull(500) {
266
             var next = 0
233
             var next = 0
267
-            for (line in client.writeChannel!!) {
268
-                val stringy = String(line)
269
-                if (stringy.startsWith("TEST ")) {
270
-                    assertEquals("TEST $next", stringy)
271
-                    if (++next == 100) {
272
-                        break
273
-                    }
234
+            for (line in sendLineChannel.map { String(it) }.filter { it.startsWith("TEST ") }) {
235
+                assertEquals("TEST $next", line)
236
+                if (++next == 100) {
237
+                    break
274
                 }
238
                 }
275
             }
239
             }
276
             next
240
             next
291
 
255
 
292
     private suspend fun IrcClientImpl.blockUntilConnected() {
256
     private suspend fun IrcClientImpl.blockUntilConnected() {
293
         // Yuck. Maybe connect should be asynchronous?
257
         // Yuck. Maybe connect should be asynchronous?
294
-        while (writeChannel == null) {
258
+        while (serverState.status <= ServerStatus.Connecting) {
295
             delay(50)
259
             delay(50)
296
         }
260
         }
297
     }
261
     }

+ 20
- 25
src/test/kotlin/com/dmdirc/ktirc/io/KtorLineBufferedSocketTest.kt View File

1
 package com.dmdirc.ktirc.io
1
 package com.dmdirc.ktirc.io
2
 
2
 
3
 import io.ktor.network.tls.certificates.generateCertificate
3
 import io.ktor.network.tls.certificates.generateCertificate
4
-import kotlinx.coroutines.GlobalScope
5
-import kotlinx.coroutines.async
6
-import kotlinx.coroutines.channels.Channel
7
-import kotlinx.coroutines.launch
8
-import kotlinx.coroutines.runBlocking
4
+import io.ktor.util.KtorExperimentalAPI
5
+import kotlinx.coroutines.*
9
 import org.junit.jupiter.api.Assertions.assertEquals
6
 import org.junit.jupiter.api.Assertions.assertEquals
10
 import org.junit.jupiter.api.Assertions.assertNotNull
7
 import org.junit.jupiter.api.Assertions.assertNotNull
11
 import org.junit.jupiter.api.Test
8
 import org.junit.jupiter.api.Test
19
 import javax.net.ssl.SSLContext
16
 import javax.net.ssl.SSLContext
20
 import javax.net.ssl.X509TrustManager
17
 import javax.net.ssl.X509TrustManager
21
 
18
 
19
+@KtorExperimentalAPI
20
+@ExperimentalCoroutinesApi
22
 @Execution(ExecutionMode.SAME_THREAD)
21
 @Execution(ExecutionMode.SAME_THREAD)
23
 internal class KtorLineBufferedSocketTest {
22
 internal class KtorLineBufferedSocketTest {
24
 
23
 
25
-    private val writeChannel = Channel<ByteArray>(Channel.UNLIMITED)
26
-
27
     @Test
24
     @Test
28
     fun `KtorLineBufferedSocket can connect to a server`() = runBlocking {
25
     fun `KtorLineBufferedSocket can connect to a server`() = runBlocking {
29
         ServerSocket(12321).use { serverSocket ->
26
         ServerSocket(12321).use { serverSocket ->
30
-            val socket = KtorLineBufferedSocket("localhost", 12321)
27
+            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
31
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
28
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
32
 
29
 
33
             socket.connect()
30
             socket.connect()
39
     @Test
36
     @Test
40
     fun `KtorLineBufferedSocket can send a byte array to a server`() = runBlocking {
37
     fun `KtorLineBufferedSocket can send a byte array to a server`() = runBlocking {
41
         ServerSocket(12321).use { serverSocket ->
38
         ServerSocket(12321).use { serverSocket ->
42
-            val socket = KtorLineBufferedSocket("localhost", 12321)
39
+            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
43
             val clientBytesAsync = GlobalScope.async {
40
             val clientBytesAsync = GlobalScope.async {
44
                 ByteArray(13).apply {
41
                 ByteArray(13).apply {
45
                     serverSocket.accept().getInputStream().read(this)
42
                     serverSocket.accept().getInputStream().read(this)
47
             }
44
             }
48
 
45
 
49
             socket.connect()
46
             socket.connect()
50
-            GlobalScope.launch { socket.writeLines(writeChannel) }
51
-            writeChannel.send("Hello World".toByteArray())
47
+            socket.sendChannel.send("Hello World".toByteArray())
52
 
48
 
53
             val bytes = clientBytesAsync.await()
49
             val bytes = clientBytesAsync.await()
54
             assertNotNull(bytes)
50
             assertNotNull(bytes)
59
     @Test
55
     @Test
60
     fun `KtorLineBufferedSocket can send a string to a server over TLS`() = runBlocking {
56
     fun `KtorLineBufferedSocket can send a string to a server over TLS`() = runBlocking {
61
         tlsServerSocket(12321).use { serverSocket ->
57
         tlsServerSocket(12321).use { serverSocket ->
62
-            val socket = KtorLineBufferedSocket("localhost", 12321, true)
58
+            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321, true)
63
             socket.tlsTrustManager = getTrustingManager()
59
             socket.tlsTrustManager = getTrustingManager()
64
             val clientBytesAsync = GlobalScope.async {
60
             val clientBytesAsync = GlobalScope.async {
65
                 ByteArray(13).apply {
61
                 ByteArray(13).apply {
68
             }
64
             }
69
 
65
 
70
             socket.connect()
66
             socket.connect()
71
-            GlobalScope.launch { socket.writeLines(writeChannel) }
72
-            writeChannel.send("Hello World".toByteArray())
67
+            socket.sendChannel.send("Hello World".toByteArray())
73
 
68
 
74
             val bytes = clientBytesAsync.await()
69
             val bytes = clientBytesAsync.await()
75
             assertNotNull(bytes)
70
             assertNotNull(bytes)
80
     @Test
75
     @Test
81
     fun `KtorLineBufferedSocket can receive a line of CRLF delimited text`() = runBlocking {
76
     fun `KtorLineBufferedSocket can receive a line of CRLF delimited text`() = runBlocking {
82
         ServerSocket(12321).use { serverSocket ->
77
         ServerSocket(12321).use { serverSocket ->
83
-            val socket = KtorLineBufferedSocket("localhost", 12321)
78
+            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
84
             GlobalScope.launch {
79
             GlobalScope.launch {
85
                 serverSocket.accept().getOutputStream().write("Hi there\r\n".toByteArray())
80
                 serverSocket.accept().getOutputStream().write("Hi there\r\n".toByteArray())
86
             }
81
             }
87
 
82
 
88
             socket.connect()
83
             socket.connect()
89
-            assertEquals("Hi there", String(socket.readLines(GlobalScope).receive()))
84
+            assertEquals("Hi there", String(socket.receiveChannel.receive()))
90
         }
85
         }
91
     }
86
     }
92
 
87
 
93
     @Test
88
     @Test
94
     fun `KtorLineBufferedSocket can receive a line of LF delimited text`() = runBlocking {
89
     fun `KtorLineBufferedSocket can receive a line of LF delimited text`() = runBlocking {
95
         ServerSocket(12321).use { serverSocket ->
90
         ServerSocket(12321).use { serverSocket ->
96
-            val socket = KtorLineBufferedSocket("localhost", 12321)
91
+            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
97
             GlobalScope.launch {
92
             GlobalScope.launch {
98
                 serverSocket.accept().getOutputStream().write("Hi there\n".toByteArray())
93
                 serverSocket.accept().getOutputStream().write("Hi there\n".toByteArray())
99
             }
94
             }
100
 
95
 
101
             socket.connect()
96
             socket.connect()
102
-            assertEquals("Hi there", String(socket.readLines(GlobalScope).receive()))
97
+            assertEquals("Hi there", String(socket.receiveChannel.receive()))
103
         }
98
         }
104
     }
99
     }
105
 
100
 
106
     @Test
101
     @Test
107
     fun `KtorLineBufferedSocket can receive multiple lines of text in one packet`() = runBlocking {
102
     fun `KtorLineBufferedSocket can receive multiple lines of text in one packet`() = runBlocking {
108
         ServerSocket(12321).use { serverSocket ->
103
         ServerSocket(12321).use { serverSocket ->
109
-            val socket = KtorLineBufferedSocket("localhost", 12321)
104
+            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
110
             GlobalScope.launch {
105
             GlobalScope.launch {
111
                 serverSocket.accept().getOutputStream().write("Hi there\nThis is a test\r".toByteArray())
106
                 serverSocket.accept().getOutputStream().write("Hi there\nThis is a test\r".toByteArray())
112
             }
107
             }
113
 
108
 
114
             socket.connect()
109
             socket.connect()
115
-            val lineProducer = socket.readLines(GlobalScope)
110
+            val lineProducer = socket.receiveChannel
116
             assertEquals("Hi there", String(lineProducer.receive()))
111
             assertEquals("Hi there", String(lineProducer.receive()))
117
             assertEquals("This is a test", String(lineProducer.receive()))
112
             assertEquals("This is a test", String(lineProducer.receive()))
118
         }
113
         }
121
     @Test
116
     @Test
122
     fun `KtorLineBufferedSocket can receive one line of text over multiple packets`() = runBlocking {
117
     fun `KtorLineBufferedSocket can receive one line of text over multiple packets`() = runBlocking {
123
         ServerSocket(12321).use { serverSocket ->
118
         ServerSocket(12321).use { serverSocket ->
124
-            val socket = KtorLineBufferedSocket("localhost", 12321)
119
+            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
125
             GlobalScope.launch {
120
             GlobalScope.launch {
126
                 with(serverSocket.accept().getOutputStream()) {
121
                 with(serverSocket.accept().getOutputStream()) {
127
                     write("Hi".toByteArray())
122
                     write("Hi".toByteArray())
134
             }
129
             }
135
 
130
 
136
             socket.connect()
131
             socket.connect()
137
-            val lineProducer = socket.readLines(GlobalScope)
132
+            val lineProducer = socket.receiveChannel
138
             assertEquals("Hi there", String(lineProducer.receive()))
133
             assertEquals("Hi there", String(lineProducer.receive()))
139
         }
134
         }
140
     }
135
     }
142
     @Test
137
     @Test
143
     fun `KtorLineBufferedSocket returns from readLines when socket is closed`() = runBlocking {
138
     fun `KtorLineBufferedSocket returns from readLines when socket is closed`() = runBlocking {
144
         ServerSocket(12321).use { serverSocket ->
139
         ServerSocket(12321).use { serverSocket ->
145
-            val socket = KtorLineBufferedSocket("localhost", 12321)
140
+            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
146
             GlobalScope.launch {
141
             GlobalScope.launch {
147
                 with(serverSocket.accept()) {
142
                 with(serverSocket.accept()) {
148
                     getOutputStream().write("Hi there\r\n".toByteArray())
143
                     getOutputStream().write("Hi there\r\n".toByteArray())
151
             }
146
             }
152
 
147
 
153
             socket.connect()
148
             socket.connect()
154
-            val lineProducer = socket.readLines(GlobalScope)
149
+            val lineProducer = socket.receiveChannel
155
             assertEquals("Hi there", String(lineProducer.receive()))
150
             assertEquals("Hi there", String(lineProducer.receive()))
156
         }
151
         }
157
     }
152
     }
159
     @Test
154
     @Test
160
     fun `KtorLineBufferedSocket disconnects from server`() = runBlocking {
155
     fun `KtorLineBufferedSocket disconnects from server`() = runBlocking {
161
         ServerSocket(12321).use { serverSocket ->
156
         ServerSocket(12321).use { serverSocket ->
162
-            val socket = KtorLineBufferedSocket("localhost", 12321)
157
+            val socket = KtorLineBufferedSocket(GlobalScope, "localhost", 12321)
163
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
158
             val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
164
 
159
 
165
             socket.connect()
160
             socket.connect()

Loading…
Cancel
Save