Browse Source

Rework how writing lines works, again.

To preserve the order of lines we offer them into a channel
without suspending. If we launch a new coroutine there's no
guarantee they'll execute in order, even if the method they
call is fair.
tags/v0.4.0
Chris Smith 5 years ago
parent
commit
74dd02ca3a

+ 1
- 0
CHANGELOG View File

@@ -2,6 +2,7 @@ vNEXT (in development)
2 2
 
3 3
  * Added CtcpReceived and ActionReceived events
4 4
  * Added sendCtcp and sendAction message builders
5
+ * Fix issue with messages being sent out of order, which sometimes caused problems connecting to passworded servers
5 6
 
6 7
 v0.3.1
7 8
 

+ 10
- 4
src/main/kotlin/com/dmdirc/ktirc/IrcClient.kt View File

@@ -6,6 +6,7 @@ import com.dmdirc.ktirc.messages.*
6 6
 import com.dmdirc.ktirc.model.*
7 7
 import com.dmdirc.ktirc.util.currentTimeProvider
8 8
 import kotlinx.coroutines.*
9
+import kotlinx.coroutines.channels.Channel
9 10
 import kotlinx.coroutines.channels.map
10 11
 import java.util.concurrent.atomic.AtomicBoolean
11 12
 import java.util.logging.Level
@@ -98,12 +99,10 @@ class IrcClientImpl(private val server: Server, private val profile: Profile) :
98 99
     private val connecting = AtomicBoolean(false)
99 100
 
100 101
     private var connectionJob: Job? = null
102
+    internal var writeChannel: Channel<ByteArray>? = null
101 103
 
102 104
     override fun send(message: String) {
103
-        // TODO: What happens if sending fails?
104
-        scope.launch {
105
-            socket?.sendLine(message)
106
-        }
105
+        writeChannel?.offer(message.toByteArray())
107 106
     }
108 107
 
109 108
     override fun connect() {
@@ -112,7 +111,14 @@ class IrcClientImpl(private val server: Server, private val profile: Profile) :
112 111
             with(socketFactory(server.host, server.port, server.tls)) {
113 112
                 // TODO: Proper error handling - what if connect() fails?
114 113
                 socket = this
114
+
115 115
                 connect()
116
+
117
+                with (Channel<ByteArray>(Channel.UNLIMITED)) {
118
+                    writeChannel = this
119
+                    scope.launch { writeLines(this@with) }
120
+                }
121
+
116 122
                 emitEvent(ServerConnected(currentTimeProvider()))
117 123
                 sendCapabilityList()
118 124
                 sendPasswordIfPresent()

+ 13
- 22
src/main/kotlin/com/dmdirc/ktirc/io/LineBufferedSocket.kt View File

@@ -15,7 +15,6 @@ import kotlinx.coroutines.channels.ReceiveChannel
15 15
 import kotlinx.coroutines.channels.produce
16 16
 import kotlinx.coroutines.io.ByteReadChannel
17 17
 import kotlinx.coroutines.io.ByteWriteChannel
18
-import kotlinx.coroutines.sync.Mutex
19 18
 import java.net.InetSocketAddress
20 19
 import java.security.SecureRandom
21 20
 import javax.net.ssl.X509TrustManager
@@ -25,10 +24,8 @@ internal interface LineBufferedSocket {
25 24
     suspend fun connect()
26 25
     fun disconnect()
27 26
 
28
-    suspend fun sendLine(line: ByteArray, offset: Int = 0, length: Int = line.size)
29
-    suspend fun sendLine(line: String)
30
-
31 27
     fun readLines(coroutineScope: CoroutineScope): ReceiveChannel<ByteArray>
28
+    suspend fun writeLines(channel: ReceiveChannel<ByteArray>)
32 29
 
33 30
 }
34 31
 
@@ -46,7 +43,6 @@ internal class KtorLineBufferedSocket(private val host: String, private val port
46 43
     var tlsTrustManager: X509TrustManager? = null
47 44
 
48 45
     private val log by logger()
49
-    private val writeLock = Mutex()
50 46
 
51 47
     private lateinit var socket: Socket
52 48
     private lateinit var readChannel: ByteReadChannel
@@ -69,23 +65,6 @@ internal class KtorLineBufferedSocket(private val host: String, private val port
69 65
         socket.close()
70 66
     }
71 67
 
72
-    override suspend fun sendLine(line: ByteArray, offset: Int, length: Int) {
73
-        writeLock.lock()
74
-        try {
75
-            with(writeChannel) {
76
-                log.fine { ">>> ${String(line, offset, length)}" }
77
-                writeAvailable(line, offset, length)
78
-                writeByte(CARRIAGE_RETURN)
79
-                writeByte(LINE_FEED)
80
-                flush()
81
-            }
82
-        } finally {
83
-            writeLock.unlock()
84
-        }
85
-    }
86
-
87
-    override suspend fun sendLine(line: String) = sendLine(line.toByteArray())
88
-
89 68
     @ExperimentalCoroutinesApi
90 69
     override fun readLines(coroutineScope: CoroutineScope) = coroutineScope.produce {
91 70
         val lineBuffer = ByteArray(4096)
@@ -107,4 +86,16 @@ internal class KtorLineBufferedSocket(private val host: String, private val port
107 86
             index = count + index - start
108 87
         }
109 88
     }
89
+
90
+    override suspend fun writeLines(channel: ReceiveChannel<ByteArray>) {
91
+        for (line in channel) {
92
+            with(writeChannel) {
93
+                log.fine { ">>> ${String(line)}" }
94
+                writeAvailable(line, 0, line.size)
95
+                writeByte(CARRIAGE_RETURN)
96
+                writeByte(LINE_FEED)
97
+                flush()
98
+            }
99
+        }
100
+    }
110 101
 }

+ 54
- 19
src/test/kotlin/com/dmdirc/ktirc/IrcClientTest.kt View File

@@ -11,11 +11,8 @@ import com.dmdirc.ktirc.model.ServerFeature
11 11
 import com.dmdirc.ktirc.model.User
12 12
 import com.dmdirc.ktirc.util.currentTimeProvider
13 13
 import com.nhaarman.mockitokotlin2.*
14
-import kotlinx.coroutines.GlobalScope
14
+import kotlinx.coroutines.*
15 15
 import kotlinx.coroutines.channels.Channel
16
-import kotlinx.coroutines.launch
17
-import kotlinx.coroutines.runBlocking
18
-import kotlinx.coroutines.withTimeoutOrNull
19 16
 import org.junit.jupiter.api.Assertions.*
20 17
 import org.junit.jupiter.api.BeforeEach
21 18
 import org.junit.jupiter.api.Test
@@ -97,11 +94,11 @@ internal class IrcClientImplTest {
97 94
         client.socketFactory = mockSocketFactory
98 95
         client.connect()
99 96
 
100
-        with(inOrder(mockSocket).verify(mockSocket, timeout(500))) {
101
-            sendLine("CAP LS 302")
102
-            sendLine("NICK :$NICK")
103
-            sendLine("USER $USER_NAME localhost $HOST :$REAL_NAME")
104
-        }
97
+        client.blockUntilConnected()
98
+
99
+        assertEquals("CAP LS 302", String(client.writeChannel!!.receive()))
100
+        assertEquals("NICK :$NICK", String(client.writeChannel!!.receive()))
101
+        assertEquals("USER $USER_NAME localhost $HOST :$REAL_NAME", String(client.writeChannel!!.receive()))
105 102
     }
106 103
 
107 104
     @Test
@@ -110,11 +107,10 @@ internal class IrcClientImplTest {
110 107
         client.socketFactory = mockSocketFactory
111 108
         client.connect()
112 109
 
113
-        with(inOrder(mockSocket).verify(mockSocket, timeout(500))) {
114
-            sendLine("CAP LS 302")
115
-            sendLine("PASS :$PASSWORD")
116
-            sendLine("NICK :$NICK")
117
-        }
110
+        client.blockUntilConnected()
111
+
112
+        assertEquals("CAP LS 302", String(client.writeChannel!!.receive()))
113
+        assertEquals("PASS :$PASSWORD", String(client.writeChannel!!.receive()))
118 114
     }
119 115
 
120 116
     @Test
@@ -199,12 +195,20 @@ internal class IrcClientImplTest {
199 195
         client.socketFactory = mockSocketFactory
200 196
         client.connect()
201 197
 
202
-        // Wait for it to connect
203
-        verify(mockSocket, timeout(500)).sendLine("CAP LS 302")
198
+        client.blockUntilConnected()
204 199
 
205 200
         client.send("testing 123")
206 201
 
207
-        verify(mockSocket, timeout(500)).sendLine("testing 123")
202
+        assertEquals(true, withTimeoutOrNull(500) {
203
+            var found = false
204
+            for (line in client.writeChannel!!) {
205
+                if (String(line) == "testing 123") {
206
+                    found = true
207
+                    break
208
+                }
209
+            }
210
+            found
211
+        })
208 212
     }
209 213
 
210 214
     @Test
@@ -213,13 +217,44 @@ internal class IrcClientImplTest {
213 217
         client.socketFactory = mockSocketFactory
214 218
         client.connect()
215 219
 
216
-        // Wait for it to connect
217
-        verify(mockSocket, timeout(500)).sendLine("CAP LS 302")
220
+        client.blockUntilConnected()
218 221
 
219 222
         client.disconnect()
220 223
 
221 224
         verify(mockSocket, timeout(500)).disconnect()
222 225
     }
223 226
 
227
+    @Test
228
+    fun `IrcClientImpl sends messages in order`() = runBlocking {
229
+        val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
230
+        client.socketFactory = mockSocketFactory
231
+        client.connect()
232
+
233
+        client.blockUntilConnected()
234
+
235
+        (0..100).forEach { client.send("TEST $it") }
236
+
237
+        assertEquals(100, withTimeoutOrNull(500) {
238
+            var next = 0
239
+            for (line in client.writeChannel!!) {
240
+                val stringy = String(line)
241
+                if (stringy.startsWith("TEST ")) {
242
+                    assertEquals("TEST $next", stringy)
243
+                    if (++next == 100) {
244
+                        break
245
+                    }
246
+                }
247
+            }
248
+            next
249
+        })
250
+    }
251
+
252
+    private suspend fun IrcClientImpl.blockUntilConnected() {
253
+        // Yuck. Maybe connect should be asynchronous?
254
+        while (writeChannel == null) {
255
+            delay(50)
256
+        }
257
+    }
258
+
224 259
 
225 260
 }

+ 10
- 59
src/test/kotlin/com/dmdirc/ktirc/io/KtorLineBufferedSocketTest.kt View File

@@ -3,16 +3,17 @@ package com.dmdirc.ktirc.io
3 3
 import io.ktor.network.tls.certificates.generateCertificate
4 4
 import kotlinx.coroutines.GlobalScope
5 5
 import kotlinx.coroutines.async
6
+import kotlinx.coroutines.channels.Channel
6 7
 import kotlinx.coroutines.launch
7 8
 import kotlinx.coroutines.runBlocking
8
-import org.junit.jupiter.api.Assertions.*
9
+import org.junit.jupiter.api.Assertions.assertEquals
10
+import org.junit.jupiter.api.Assertions.assertNotNull
9 11
 import org.junit.jupiter.api.Test
10 12
 import org.junit.jupiter.api.parallel.Execution
11 13
 import org.junit.jupiter.api.parallel.ExecutionMode
12 14
 import java.io.File
13 15
 import java.net.ServerSocket
14 16
 import java.security.KeyStore
15
-import java.security.cert.CertificateException
16 17
 import java.security.cert.X509Certificate
17 18
 import javax.net.ssl.KeyManagerFactory
18 19
 import javax.net.ssl.SSLContext
@@ -21,6 +22,8 @@ import javax.net.ssl.X509TrustManager
21 22
 @Execution(ExecutionMode.SAME_THREAD)
22 23
 internal class KtorLineBufferedSocketTest {
23 24
 
25
+    private val writeChannel = Channel<ByteArray>(Channel.UNLIMITED)
26
+
24 27
     @Test
25 28
     fun `KtorLineBufferedSocket can connect to a server`() = runBlocking {
26 29
         ServerSocket(12321).use { serverSocket ->
@@ -34,23 +37,7 @@ internal class KtorLineBufferedSocketTest {
34 37
     }
35 38
 
36 39
     @Test
37
-    fun `KtorLineBufferedSocket throws trying to connect to a server with a bad TLS cert`() = runBlocking {
38
-        tlsServerSocket(12321).use { serverSocket ->
39
-            try {
40
-                val socket = KtorLineBufferedSocket("localhost", 12321)
41
-                val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
42
-
43
-                socket.connect()
44
-                assertNotNull(clientSocketAsync.await())
45
-                fail<Unit>()
46
-            } catch (ex : CertificateException) {
47
-                // Expected
48
-            }
49
-        }
50
-    }
51
-
52
-    @Test
53
-    fun `KtorLineBufferedSocket can send a whole byte array to a server`() = runBlocking {
40
+    fun `KtorLineBufferedSocket can send a byte array to a server`() = runBlocking {
54 41
         ServerSocket(12321).use { serverSocket ->
55 42
             val socket = KtorLineBufferedSocket("localhost", 12321)
56 43
             val clientBytesAsync = GlobalScope.async {
@@ -60,26 +47,8 @@ internal class KtorLineBufferedSocketTest {
60 47
             }
61 48
 
62 49
             socket.connect()
63
-            socket.sendLine("Hello World".toByteArray())
64
-
65
-            val bytes = clientBytesAsync.await()
66
-            assertNotNull(bytes)
67
-            assertEquals("Hello World\r\n", String(bytes))
68
-        }
69
-    }
70
-
71
-    @Test
72
-    fun `KtorLineBufferedSocket can send a string to a server`() = runBlocking {
73
-        ServerSocket(12321).use { serverSocket ->
74
-            val socket = KtorLineBufferedSocket("localhost", 12321)
75
-            val clientBytesAsync = GlobalScope.async {
76
-                ByteArray(13).apply {
77
-                    serverSocket.accept().getInputStream().read(this)
78
-                }
79
-            }
80
-
81
-            socket.connect()
82
-            socket.sendLine("Hello World")
50
+            GlobalScope.launch { socket.writeLines(writeChannel) }
51
+            writeChannel.send("Hello World".toByteArray())
83 52
 
84 53
             val bytes = clientBytesAsync.await()
85 54
             assertNotNull(bytes)
@@ -99,7 +68,8 @@ internal class KtorLineBufferedSocketTest {
99 68
             }
100 69
 
101 70
             socket.connect()
102
-            socket.sendLine("Hello World")
71
+            GlobalScope.launch { socket.writeLines(writeChannel) }
72
+            writeChannel.send("Hello World".toByteArray())
103 73
 
104 74
             val bytes = clientBytesAsync.await()
105 75
             assertNotNull(bytes)
@@ -107,25 +77,6 @@ internal class KtorLineBufferedSocketTest {
107 77
         }
108 78
     }
109 79
 
110
-    @Test
111
-    fun `KtorLineBufferedSocket can send a partial byte array to a server`() = runBlocking {
112
-        ServerSocket(12321).use { serverSocket ->
113
-            val socket = KtorLineBufferedSocket("localhost", 12321)
114
-            val clientBytesAsync = GlobalScope.async {
115
-                ByteArray(7).apply {
116
-                    serverSocket.accept().getInputStream().read(this)
117
-                }
118
-            }
119
-
120
-            socket.connect()
121
-            socket.sendLine("Hello World".toByteArray(), 6, 5)
122
-
123
-            val bytes = clientBytesAsync.await()
124
-            assertNotNull(bytes)
125
-            assertEquals("World\r\n", String(bytes))
126
-        }
127
-    }
128
-
129 80
     @Test
130 81
     fun `KtorLineBufferedSocket can receive a line of CRLF delimited text`() = runBlocking {
131 82
         ServerSocket(12321).use { serverSocket ->

Loading…
Cancel
Save