Procházet zdrojové kódy

Rework how writing lines works, again.

To preserve the order of lines we offer them into a channel
without suspending. If we launch a new coroutine there's no
guarantee they'll execute in order, even if the method they
call is fair.
tags/v0.4.0
Chris Smith před 5 roky
rodič
revize
74dd02ca3a

+ 1
- 0
CHANGELOG Zobrazit soubor

2
 
2
 
3
  * Added CtcpReceived and ActionReceived events
3
  * Added CtcpReceived and ActionReceived events
4
  * Added sendCtcp and sendAction message builders
4
  * Added sendCtcp and sendAction message builders
5
+ * Fix issue with messages being sent out of order, which sometimes caused problems connecting to passworded servers
5
 
6
 
6
 v0.3.1
7
 v0.3.1
7
 
8
 

+ 10
- 4
src/main/kotlin/com/dmdirc/ktirc/IrcClient.kt Zobrazit soubor

6
 import com.dmdirc.ktirc.model.*
6
 import com.dmdirc.ktirc.model.*
7
 import com.dmdirc.ktirc.util.currentTimeProvider
7
 import com.dmdirc.ktirc.util.currentTimeProvider
8
 import kotlinx.coroutines.*
8
 import kotlinx.coroutines.*
9
+import kotlinx.coroutines.channels.Channel
9
 import kotlinx.coroutines.channels.map
10
 import kotlinx.coroutines.channels.map
10
 import java.util.concurrent.atomic.AtomicBoolean
11
 import java.util.concurrent.atomic.AtomicBoolean
11
 import java.util.logging.Level
12
 import java.util.logging.Level
98
     private val connecting = AtomicBoolean(false)
99
     private val connecting = AtomicBoolean(false)
99
 
100
 
100
     private var connectionJob: Job? = null
101
     private var connectionJob: Job? = null
102
+    internal var writeChannel: Channel<ByteArray>? = null
101
 
103
 
102
     override fun send(message: String) {
104
     override fun send(message: String) {
103
-        // TODO: What happens if sending fails?
104
-        scope.launch {
105
-            socket?.sendLine(message)
106
-        }
105
+        writeChannel?.offer(message.toByteArray())
107
     }
106
     }
108
 
107
 
109
     override fun connect() {
108
     override fun connect() {
112
             with(socketFactory(server.host, server.port, server.tls)) {
111
             with(socketFactory(server.host, server.port, server.tls)) {
113
                 // TODO: Proper error handling - what if connect() fails?
112
                 // TODO: Proper error handling - what if connect() fails?
114
                 socket = this
113
                 socket = this
114
+
115
                 connect()
115
                 connect()
116
+
117
+                with (Channel<ByteArray>(Channel.UNLIMITED)) {
118
+                    writeChannel = this
119
+                    scope.launch { writeLines(this@with) }
120
+                }
121
+
116
                 emitEvent(ServerConnected(currentTimeProvider()))
122
                 emitEvent(ServerConnected(currentTimeProvider()))
117
                 sendCapabilityList()
123
                 sendCapabilityList()
118
                 sendPasswordIfPresent()
124
                 sendPasswordIfPresent()

+ 13
- 22
src/main/kotlin/com/dmdirc/ktirc/io/LineBufferedSocket.kt Zobrazit soubor

15
 import kotlinx.coroutines.channels.produce
15
 import kotlinx.coroutines.channels.produce
16
 import kotlinx.coroutines.io.ByteReadChannel
16
 import kotlinx.coroutines.io.ByteReadChannel
17
 import kotlinx.coroutines.io.ByteWriteChannel
17
 import kotlinx.coroutines.io.ByteWriteChannel
18
-import kotlinx.coroutines.sync.Mutex
19
 import java.net.InetSocketAddress
18
 import java.net.InetSocketAddress
20
 import java.security.SecureRandom
19
 import java.security.SecureRandom
21
 import javax.net.ssl.X509TrustManager
20
 import javax.net.ssl.X509TrustManager
25
     suspend fun connect()
24
     suspend fun connect()
26
     fun disconnect()
25
     fun disconnect()
27
 
26
 
28
-    suspend fun sendLine(line: ByteArray, offset: Int = 0, length: Int = line.size)
29
-    suspend fun sendLine(line: String)
30
-
31
     fun readLines(coroutineScope: CoroutineScope): ReceiveChannel<ByteArray>
27
     fun readLines(coroutineScope: CoroutineScope): ReceiveChannel<ByteArray>
28
+    suspend fun writeLines(channel: ReceiveChannel<ByteArray>)
32
 
29
 
33
 }
30
 }
34
 
31
 
46
     var tlsTrustManager: X509TrustManager? = null
43
     var tlsTrustManager: X509TrustManager? = null
47
 
44
 
48
     private val log by logger()
45
     private val log by logger()
49
-    private val writeLock = Mutex()
50
 
46
 
51
     private lateinit var socket: Socket
47
     private lateinit var socket: Socket
52
     private lateinit var readChannel: ByteReadChannel
48
     private lateinit var readChannel: ByteReadChannel
69
         socket.close()
65
         socket.close()
70
     }
66
     }
71
 
67
 
72
-    override suspend fun sendLine(line: ByteArray, offset: Int, length: Int) {
73
-        writeLock.lock()
74
-        try {
75
-            with(writeChannel) {
76
-                log.fine { ">>> ${String(line, offset, length)}" }
77
-                writeAvailable(line, offset, length)
78
-                writeByte(CARRIAGE_RETURN)
79
-                writeByte(LINE_FEED)
80
-                flush()
81
-            }
82
-        } finally {
83
-            writeLock.unlock()
84
-        }
85
-    }
86
-
87
-    override suspend fun sendLine(line: String) = sendLine(line.toByteArray())
88
-
89
     @ExperimentalCoroutinesApi
68
     @ExperimentalCoroutinesApi
90
     override fun readLines(coroutineScope: CoroutineScope) = coroutineScope.produce {
69
     override fun readLines(coroutineScope: CoroutineScope) = coroutineScope.produce {
91
         val lineBuffer = ByteArray(4096)
70
         val lineBuffer = ByteArray(4096)
107
             index = count + index - start
86
             index = count + index - start
108
         }
87
         }
109
     }
88
     }
89
+
90
+    override suspend fun writeLines(channel: ReceiveChannel<ByteArray>) {
91
+        for (line in channel) {
92
+            with(writeChannel) {
93
+                log.fine { ">>> ${String(line)}" }
94
+                writeAvailable(line, 0, line.size)
95
+                writeByte(CARRIAGE_RETURN)
96
+                writeByte(LINE_FEED)
97
+                flush()
98
+            }
99
+        }
100
+    }
110
 }
101
 }

+ 54
- 19
src/test/kotlin/com/dmdirc/ktirc/IrcClientTest.kt Zobrazit soubor

11
 import com.dmdirc.ktirc.model.User
11
 import com.dmdirc.ktirc.model.User
12
 import com.dmdirc.ktirc.util.currentTimeProvider
12
 import com.dmdirc.ktirc.util.currentTimeProvider
13
 import com.nhaarman.mockitokotlin2.*
13
 import com.nhaarman.mockitokotlin2.*
14
-import kotlinx.coroutines.GlobalScope
14
+import kotlinx.coroutines.*
15
 import kotlinx.coroutines.channels.Channel
15
 import kotlinx.coroutines.channels.Channel
16
-import kotlinx.coroutines.launch
17
-import kotlinx.coroutines.runBlocking
18
-import kotlinx.coroutines.withTimeoutOrNull
19
 import org.junit.jupiter.api.Assertions.*
16
 import org.junit.jupiter.api.Assertions.*
20
 import org.junit.jupiter.api.BeforeEach
17
 import org.junit.jupiter.api.BeforeEach
21
 import org.junit.jupiter.api.Test
18
 import org.junit.jupiter.api.Test
97
         client.socketFactory = mockSocketFactory
94
         client.socketFactory = mockSocketFactory
98
         client.connect()
95
         client.connect()
99
 
96
 
100
-        with(inOrder(mockSocket).verify(mockSocket, timeout(500))) {
101
-            sendLine("CAP LS 302")
102
-            sendLine("NICK :$NICK")
103
-            sendLine("USER $USER_NAME localhost $HOST :$REAL_NAME")
104
-        }
97
+        client.blockUntilConnected()
98
+
99
+        assertEquals("CAP LS 302", String(client.writeChannel!!.receive()))
100
+        assertEquals("NICK :$NICK", String(client.writeChannel!!.receive()))
101
+        assertEquals("USER $USER_NAME localhost $HOST :$REAL_NAME", String(client.writeChannel!!.receive()))
105
     }
102
     }
106
 
103
 
107
     @Test
104
     @Test
110
         client.socketFactory = mockSocketFactory
107
         client.socketFactory = mockSocketFactory
111
         client.connect()
108
         client.connect()
112
 
109
 
113
-        with(inOrder(mockSocket).verify(mockSocket, timeout(500))) {
114
-            sendLine("CAP LS 302")
115
-            sendLine("PASS :$PASSWORD")
116
-            sendLine("NICK :$NICK")
117
-        }
110
+        client.blockUntilConnected()
111
+
112
+        assertEquals("CAP LS 302", String(client.writeChannel!!.receive()))
113
+        assertEquals("PASS :$PASSWORD", String(client.writeChannel!!.receive()))
118
     }
114
     }
119
 
115
 
120
     @Test
116
     @Test
199
         client.socketFactory = mockSocketFactory
195
         client.socketFactory = mockSocketFactory
200
         client.connect()
196
         client.connect()
201
 
197
 
202
-        // Wait for it to connect
203
-        verify(mockSocket, timeout(500)).sendLine("CAP LS 302")
198
+        client.blockUntilConnected()
204
 
199
 
205
         client.send("testing 123")
200
         client.send("testing 123")
206
 
201
 
207
-        verify(mockSocket, timeout(500)).sendLine("testing 123")
202
+        assertEquals(true, withTimeoutOrNull(500) {
203
+            var found = false
204
+            for (line in client.writeChannel!!) {
205
+                if (String(line) == "testing 123") {
206
+                    found = true
207
+                    break
208
+                }
209
+            }
210
+            found
211
+        })
208
     }
212
     }
209
 
213
 
210
     @Test
214
     @Test
213
         client.socketFactory = mockSocketFactory
217
         client.socketFactory = mockSocketFactory
214
         client.connect()
218
         client.connect()
215
 
219
 
216
-        // Wait for it to connect
217
-        verify(mockSocket, timeout(500)).sendLine("CAP LS 302")
220
+        client.blockUntilConnected()
218
 
221
 
219
         client.disconnect()
222
         client.disconnect()
220
 
223
 
221
         verify(mockSocket, timeout(500)).disconnect()
224
         verify(mockSocket, timeout(500)).disconnect()
222
     }
225
     }
223
 
226
 
227
+    @Test
228
+    fun `IrcClientImpl sends messages in order`() = runBlocking {
229
+        val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
230
+        client.socketFactory = mockSocketFactory
231
+        client.connect()
232
+
233
+        client.blockUntilConnected()
234
+
235
+        (0..100).forEach { client.send("TEST $it") }
236
+
237
+        assertEquals(100, withTimeoutOrNull(500) {
238
+            var next = 0
239
+            for (line in client.writeChannel!!) {
240
+                val stringy = String(line)
241
+                if (stringy.startsWith("TEST ")) {
242
+                    assertEquals("TEST $next", stringy)
243
+                    if (++next == 100) {
244
+                        break
245
+                    }
246
+                }
247
+            }
248
+            next
249
+        })
250
+    }
251
+
252
+    private suspend fun IrcClientImpl.blockUntilConnected() {
253
+        // Yuck. Maybe connect should be asynchronous?
254
+        while (writeChannel == null) {
255
+            delay(50)
256
+        }
257
+    }
258
+
224
 
259
 
225
 }
260
 }

+ 10
- 59
src/test/kotlin/com/dmdirc/ktirc/io/KtorLineBufferedSocketTest.kt Zobrazit soubor

3
 import io.ktor.network.tls.certificates.generateCertificate
3
 import io.ktor.network.tls.certificates.generateCertificate
4
 import kotlinx.coroutines.GlobalScope
4
 import kotlinx.coroutines.GlobalScope
5
 import kotlinx.coroutines.async
5
 import kotlinx.coroutines.async
6
+import kotlinx.coroutines.channels.Channel
6
 import kotlinx.coroutines.launch
7
 import kotlinx.coroutines.launch
7
 import kotlinx.coroutines.runBlocking
8
 import kotlinx.coroutines.runBlocking
8
-import org.junit.jupiter.api.Assertions.*
9
+import org.junit.jupiter.api.Assertions.assertEquals
10
+import org.junit.jupiter.api.Assertions.assertNotNull
9
 import org.junit.jupiter.api.Test
11
 import org.junit.jupiter.api.Test
10
 import org.junit.jupiter.api.parallel.Execution
12
 import org.junit.jupiter.api.parallel.Execution
11
 import org.junit.jupiter.api.parallel.ExecutionMode
13
 import org.junit.jupiter.api.parallel.ExecutionMode
12
 import java.io.File
14
 import java.io.File
13
 import java.net.ServerSocket
15
 import java.net.ServerSocket
14
 import java.security.KeyStore
16
 import java.security.KeyStore
15
-import java.security.cert.CertificateException
16
 import java.security.cert.X509Certificate
17
 import java.security.cert.X509Certificate
17
 import javax.net.ssl.KeyManagerFactory
18
 import javax.net.ssl.KeyManagerFactory
18
 import javax.net.ssl.SSLContext
19
 import javax.net.ssl.SSLContext
21
 @Execution(ExecutionMode.SAME_THREAD)
22
 @Execution(ExecutionMode.SAME_THREAD)
22
 internal class KtorLineBufferedSocketTest {
23
 internal class KtorLineBufferedSocketTest {
23
 
24
 
25
+    private val writeChannel = Channel<ByteArray>(Channel.UNLIMITED)
26
+
24
     @Test
27
     @Test
25
     fun `KtorLineBufferedSocket can connect to a server`() = runBlocking {
28
     fun `KtorLineBufferedSocket can connect to a server`() = runBlocking {
26
         ServerSocket(12321).use { serverSocket ->
29
         ServerSocket(12321).use { serverSocket ->
34
     }
37
     }
35
 
38
 
36
     @Test
39
     @Test
37
-    fun `KtorLineBufferedSocket throws trying to connect to a server with a bad TLS cert`() = runBlocking {
38
-        tlsServerSocket(12321).use { serverSocket ->
39
-            try {
40
-                val socket = KtorLineBufferedSocket("localhost", 12321)
41
-                val clientSocketAsync = GlobalScope.async { serverSocket.accept() }
42
-
43
-                socket.connect()
44
-                assertNotNull(clientSocketAsync.await())
45
-                fail<Unit>()
46
-            } catch (ex : CertificateException) {
47
-                // Expected
48
-            }
49
-        }
50
-    }
51
-
52
-    @Test
53
-    fun `KtorLineBufferedSocket can send a whole byte array to a server`() = runBlocking {
40
+    fun `KtorLineBufferedSocket can send a byte array to a server`() = runBlocking {
54
         ServerSocket(12321).use { serverSocket ->
41
         ServerSocket(12321).use { serverSocket ->
55
             val socket = KtorLineBufferedSocket("localhost", 12321)
42
             val socket = KtorLineBufferedSocket("localhost", 12321)
56
             val clientBytesAsync = GlobalScope.async {
43
             val clientBytesAsync = GlobalScope.async {
60
             }
47
             }
61
 
48
 
62
             socket.connect()
49
             socket.connect()
63
-            socket.sendLine("Hello World".toByteArray())
64
-
65
-            val bytes = clientBytesAsync.await()
66
-            assertNotNull(bytes)
67
-            assertEquals("Hello World\r\n", String(bytes))
68
-        }
69
-    }
70
-
71
-    @Test
72
-    fun `KtorLineBufferedSocket can send a string to a server`() = runBlocking {
73
-        ServerSocket(12321).use { serverSocket ->
74
-            val socket = KtorLineBufferedSocket("localhost", 12321)
75
-            val clientBytesAsync = GlobalScope.async {
76
-                ByteArray(13).apply {
77
-                    serverSocket.accept().getInputStream().read(this)
78
-                }
79
-            }
80
-
81
-            socket.connect()
82
-            socket.sendLine("Hello World")
50
+            GlobalScope.launch { socket.writeLines(writeChannel) }
51
+            writeChannel.send("Hello World".toByteArray())
83
 
52
 
84
             val bytes = clientBytesAsync.await()
53
             val bytes = clientBytesAsync.await()
85
             assertNotNull(bytes)
54
             assertNotNull(bytes)
99
             }
68
             }
100
 
69
 
101
             socket.connect()
70
             socket.connect()
102
-            socket.sendLine("Hello World")
71
+            GlobalScope.launch { socket.writeLines(writeChannel) }
72
+            writeChannel.send("Hello World".toByteArray())
103
 
73
 
104
             val bytes = clientBytesAsync.await()
74
             val bytes = clientBytesAsync.await()
105
             assertNotNull(bytes)
75
             assertNotNull(bytes)
107
         }
77
         }
108
     }
78
     }
109
 
79
 
110
-    @Test
111
-    fun `KtorLineBufferedSocket can send a partial byte array to a server`() = runBlocking {
112
-        ServerSocket(12321).use { serverSocket ->
113
-            val socket = KtorLineBufferedSocket("localhost", 12321)
114
-            val clientBytesAsync = GlobalScope.async {
115
-                ByteArray(7).apply {
116
-                    serverSocket.accept().getInputStream().read(this)
117
-                }
118
-            }
119
-
120
-            socket.connect()
121
-            socket.sendLine("Hello World".toByteArray(), 6, 5)
122
-
123
-            val bytes = clientBytesAsync.await()
124
-            assertNotNull(bytes)
125
-            assertEquals("World\r\n", String(bytes))
126
-        }
127
-    }
128
-
129
     @Test
80
     @Test
130
     fun `KtorLineBufferedSocket can receive a line of CRLF delimited text`() = runBlocking {
81
     fun `KtorLineBufferedSocket can receive a line of CRLF delimited text`() = runBlocking {
131
         ServerSocket(12321).use { serverSocket ->
82
         ServerSocket(12321).use { serverSocket ->

Načítá se…
Zrušit
Uložit