Browse Source

Rework coroutines to make more sense.

tags/v0.2.0
Chris Smith 5 years ago
parent
commit
3686c5aa2d

+ 1
- 0
CHANGELOG View File

@@ -2,4 +2,5 @@ v0.2.0 [in development]
2 2
 
3 3
   * Added support for connecting over TLS
4 4
   * Simplified how event handlers are registered
5
+  * Improved use of coroutines so users don't have to worry about them
5 6
   * (Internal) Upgraded to Gradle 5.1.1

+ 25
- 17
src/main/kotlin/com/dmdirc/ktirc/IrcClient.kt View File

@@ -4,16 +4,16 @@ import com.dmdirc.ktirc.events.*
4 4
 import com.dmdirc.ktirc.io.*
5 5
 import com.dmdirc.ktirc.messages.*
6 6
 import com.dmdirc.ktirc.model.*
7
+import kotlinx.coroutines.*
7 8
 import kotlinx.coroutines.channels.map
8
-import kotlinx.coroutines.coroutineScope
9
-import kotlinx.coroutines.runBlocking
9
+import java.util.concurrent.atomic.AtomicBoolean
10 10
 import java.util.logging.Level
11 11
 import java.util.logging.LogManager
12 12
 
13 13
 
14 14
 interface IrcClient {
15 15
 
16
-    suspend fun send(message: String)
16
+    fun send(message: String)
17 17
 
18 18
     val serverState: ServerState
19 19
     val channelState: ChannelStateMap
@@ -44,15 +44,20 @@ class IrcClientImpl(private val server: Server, private val profile: Profile) :
44 44
     private val parser = MessageParser()
45 45
     private var socket: LineBufferedSocket? = null
46 46
 
47
-    // TODO: It would be cleaner if this didn't suspend but returned immediately
48
-    override suspend fun send(message: String) {
49
-        socket?.sendLine(message)
47
+    private val scope = CoroutineScope(Dispatchers.IO)
48
+    private val connecting = AtomicBoolean(false)
49
+
50
+    private var connectionJob: Job? = null
51
+
52
+    override fun send(message: String) {
53
+        scope.launch {
54
+            socket?.sendLine(message)
55
+        }
50 56
     }
51 57
 
52
-    suspend fun connect() {
53
-        // TODO: Concurrency!
54
-        check(socket == null)
55
-        coroutineScope {
58
+    fun connect() {
59
+        check(!connecting.getAndSet(true))
60
+        connectionJob = scope.launch {
56 61
             with(socketFactory(server.host, server.port, server.tls)) {
57 62
                 socket = this
58 63
                 connect()
@@ -62,7 +67,7 @@ class IrcClientImpl(private val server: Server, private val profile: Profile) :
62 67
                 // TODO: Send correct host
63 68
                 sendLine(userMessage(profile.userName, "localhost", server.host, profile.realName))
64 69
                 // TODO: This should be elsewhere
65
-                messageHandler.processMessages(this@IrcClientImpl, readLines(this@coroutineScope).map { parser.parse(it) })
70
+                messageHandler.processMessages(this@IrcClientImpl, readLines(scope).map { parser.parse(it) })
66 71
             }
67 72
         }
68 73
     }
@@ -71,9 +76,13 @@ class IrcClientImpl(private val server: Server, private val profile: Profile) :
71 76
         socket?.disconnect()
72 77
     }
73 78
 
79
+    suspend fun join() {
80
+        connectionJob?.join()
81
+    }
82
+
74 83
     override fun onEvent(handler: (IrcEvent) -> Unit) {
75 84
         messageHandler.handlers.add(object : EventHandler {
76
-            override suspend fun processEvent(client: IrcClient, event: IrcEvent) {
85
+            override fun processEvent(client: IrcClient, event: IrcEvent) {
77 86
                 handler(event)
78 87
             }
79 88
         })
@@ -90,13 +99,12 @@ fun main() {
90 99
     runBlocking {
91 100
         val client = IrcClientImpl(Server("testnet.inspircd.org", 6667), Profile("KtIrc", "Kotlin!", "kotlin"))
92 101
         client.onEvent { event ->
93
-            runBlocking {
94
-                when (event) {
95
-                    is ServerWelcome -> client.send(joinMessage("#ktirc"))
96
-                    is MessageReceived -> if (event.message == "!test") client.send(privmsgMessage(event.target, "Test successful!"))
97
-                }
102
+            when (event) {
103
+                is ServerWelcome -> client.send(joinMessage("#ktirc"))
104
+                is MessageReceived -> if (event.message == "!test") client.send(privmsgMessage(event.target, "Test successful!"))
98 105
             }
99 106
         }
100 107
         client.connect()
108
+        client.join()
101 109
     }
102 110
 }

+ 3
- 3
src/main/kotlin/com/dmdirc/ktirc/events/CapabilitiesHandler.kt View File

@@ -12,7 +12,7 @@ class CapabilitiesHandler : EventHandler {
12 12
 
13 13
     private val log by logger()
14 14
 
15
-    override suspend fun processEvent(client: IrcClient, event: IrcEvent) {
15
+    override fun processEvent(client: IrcClient, event: IrcEvent) {
16 16
         when (event) {
17 17
             is ServerCapabilitiesReceived -> handleCapabilitiesReceived(client.serverState.capabilities, event.capabilities)
18 18
             is ServerCapabilitiesFinished -> handleCapabilitiesFinished(client)
@@ -24,7 +24,7 @@ class CapabilitiesHandler : EventHandler {
24 24
         state.advertisedCapabilities.putAll(capabilities)
25 25
     }
26 26
 
27
-    private suspend fun handleCapabilitiesFinished(client: IrcClient) {
27
+    private fun handleCapabilitiesFinished(client: IrcClient) {
28 28
         // TODO: We probably need to split the outgoing REQ lines if there are lots of caps
29 29
         // TODO: For caps with values we'll need to decide which value to use/whether to enable them/etc
30 30
         with (client.serverState.capabilities) {
@@ -41,7 +41,7 @@ class CapabilitiesHandler : EventHandler {
41 41
         }
42 42
     }
43 43
 
44
-    private suspend fun handleCapabilitiesAcknowledged(client: IrcClient, capabilities: Map<Capability, String>) {
44
+    private fun handleCapabilitiesAcknowledged(client: IrcClient, capabilities: Map<Capability, String>) {
45 45
         // TODO: Check if everything we wanted is enabled
46 46
         with (client.serverState.capabilities) {
47 47
             log.info { "Acknowledged capabilities: ${capabilities.keys.map { it.name }.toList()}" }

+ 1
- 3
src/main/kotlin/com/dmdirc/ktirc/events/ChannelStateHandler.kt View File

@@ -9,7 +9,7 @@ class ChannelStateHandler : EventHandler {
9 9
 
10 10
     private val log by logger()
11 11
 
12
-    override suspend fun processEvent(client: IrcClient, event: IrcEvent) {
12
+    override fun processEvent(client: IrcClient, event: IrcEvent) {
13 13
         when (event) {
14 14
             is ChannelJoined -> handleJoin(client, event)
15 15
             is ChannelParted -> handlePart(client, event)
@@ -64,6 +64,4 @@ class ChannelStateHandler : EventHandler {
64 64
         client.channelState.forEach { it.users -= event.user.nickname }
65 65
     }
66 66
 
67
-    private fun String.nickname(prefixLength: Int) = substring(prefixLength).substringBefore('!')
68
-
69 67
 }

+ 1
- 1
src/main/kotlin/com/dmdirc/ktirc/events/EventHandler.kt View File

@@ -5,7 +5,7 @@ import com.dmdirc.ktirc.IrcClient
5 5
 @FunctionalInterface
6 6
 interface EventHandler {
7 7
 
8
-    suspend fun processEvent(client: IrcClient, event: IrcEvent)
8
+    fun processEvent(client: IrcClient, event: IrcEvent)
9 9
 
10 10
 }
11 11
 

+ 1
- 1
src/main/kotlin/com/dmdirc/ktirc/events/PingHandler.kt View File

@@ -5,7 +5,7 @@ import com.dmdirc.ktirc.messages.pongMessage
5 5
 
6 6
 class PingHandler : EventHandler {
7 7
 
8
-    override suspend fun processEvent(client: IrcClient, event: IrcEvent) {
8
+    override fun processEvent(client: IrcClient, event: IrcEvent) {
9 9
         when (event) {
10 10
             is PingReceived -> client.send(pongMessage(event.nonce))
11 11
         }

+ 1
- 1
src/main/kotlin/com/dmdirc/ktirc/events/ServerStateHandler.kt View File

@@ -4,7 +4,7 @@ import com.dmdirc.ktirc.IrcClient
4 4
 
5 5
 class ServerStateHandler : EventHandler {
6 6
 
7
-    override suspend fun processEvent(client: IrcClient, event: IrcEvent) {
7
+    override fun processEvent(client: IrcClient, event: IrcEvent) {
8 8
         when (event) {
9 9
             is ServerWelcome -> client.serverState.localNickname = event.localNick
10 10
             is ServerFeaturesUpdated -> client.serverState.features.setAll(event.serverFeatures)

+ 1
- 1
src/main/kotlin/com/dmdirc/ktirc/events/UserStateHandler.kt View File

@@ -5,7 +5,7 @@ import com.dmdirc.ktirc.model.UserState
5 5
 
6 6
 class UserStateHandler : EventHandler {
7 7
 
8
-    override suspend fun processEvent(client: IrcClient, event: IrcEvent) {
8
+    override fun processEvent(client: IrcClient, event: IrcEvent) {
9 9
         when (event) {
10 10
             is ChannelJoined -> handleJoin(client.userState, event)
11 11
             is ChannelParted -> handlePart(client, event)

+ 2
- 0
src/main/kotlin/com/dmdirc/ktirc/io/LineBufferedSocket.kt View File

@@ -9,6 +9,7 @@ import io.ktor.network.sockets.openWriteChannel
9 9
 import io.ktor.network.tls.tls
10 10
 import kotlinx.coroutines.CoroutineScope
11 11
 import kotlinx.coroutines.Dispatchers
12
+import kotlinx.coroutines.ExperimentalCoroutinesApi
12 13
 import kotlinx.coroutines.GlobalScope
13 14
 import kotlinx.coroutines.channels.ReceiveChannel
14 15
 import kotlinx.coroutines.channels.produce
@@ -78,6 +79,7 @@ class KtorLineBufferedSocket(private val host: String, private val port: Int, pr
78 79
 
79 80
     override suspend fun sendLine(line: String) = sendLine(line.toByteArray())
80 81
 
82
+    @ExperimentalCoroutinesApi
81 83
     override fun readLines(coroutineScope: CoroutineScope) = coroutineScope.produce {
82 84
         val lineBuffer = ByteArray(4096)
83 85
         var index = 0

+ 76
- 62
src/test/kotlin/com/dmdirc/ktirc/IrcClientTest.kt View File

@@ -6,9 +6,11 @@ import com.dmdirc.ktirc.io.CaseMapping
6 6
 import com.dmdirc.ktirc.io.LineBufferedSocket
7 7
 import com.dmdirc.ktirc.model.*
8 8
 import com.nhaarman.mockitokotlin2.*
9
+import kotlinx.coroutines.GlobalScope
9 10
 import kotlinx.coroutines.channels.Channel
10 11
 import kotlinx.coroutines.launch
11 12
 import kotlinx.coroutines.runBlocking
13
+import kotlinx.coroutines.withTimeoutOrNull
12 14
 import org.junit.jupiter.api.Assertions.*
13 15
 import org.junit.jupiter.api.BeforeEach
14 16
 import org.junit.jupiter.api.Test
@@ -44,97 +46,72 @@ internal class IrcClientImplTest {
44 46
 
45 47
     @Test
46 48
     fun `IrcClientImpl uses socket factory to create a new socket on connect`() {
47
-        runBlocking {
48
-            val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
49
-            client.socketFactory = mockSocketFactory
50
-            readLineChannel.close()
51
-
52
-            client.connect()
49
+        val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
50
+        client.socketFactory = mockSocketFactory
51
+        client.connect()
53 52
 
54
-            verify(mockSocketFactory).invoke(HOST, PORT, false)
55
-        }
53
+        verify(mockSocketFactory, timeout(500)).invoke(HOST, PORT, false)
56 54
     }
57 55
 
58 56
     @Test
59 57
     fun `IrcClientImpl uses socket factory to create a new tls on connect`() {
60
-        runBlocking {
61
-            val client = IrcClientImpl(Server(HOST, PORT, true), Profile(NICK, REAL_NAME, USER_NAME))
62
-            client.socketFactory = mockSocketFactory
63
-            readLineChannel.close()
64
-
65
-            client.connect()
58
+        val client = IrcClientImpl(Server(HOST, PORT, true), Profile(NICK, REAL_NAME, USER_NAME))
59
+        client.socketFactory = mockSocketFactory
60
+        client.connect()
66 61
 
67
-            verify(mockSocketFactory).invoke(HOST, PORT, true)
68
-        }
62
+        verify(mockSocketFactory, timeout(500)).invoke(HOST, PORT, true)
69 63
     }
70 64
 
71 65
     @Test
72 66
     fun `IrcClientImpl throws if socket already exists`() {
73
-        runBlocking {
74
-            val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
75
-            client.socketFactory = mockSocketFactory
76
-            readLineChannel.close()
67
+        val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
68
+        client.socketFactory = mockSocketFactory
69
+        client.connect()
77 70
 
71
+        assertThrows<IllegalStateException> {
78 72
             client.connect()
79
-
80
-            assertThrows<IllegalStateException> {
81
-                runBlocking {
82
-                    client.connect()
83
-                }
84
-            }
85 73
         }
86 74
     }
87 75
 
88 76
     @Test
89
-    fun `IrcClientImpl sends basic connection strings`() {
90
-        runBlocking {
91
-            val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
92
-            client.socketFactory = mockSocketFactory
93
-            readLineChannel.close()
94
-
95
-            client.connect()
77
+    fun `IrcClientImpl sends basic connection strings`() = runBlocking {
78
+        val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
79
+        client.socketFactory = mockSocketFactory
80
+        client.connect()
96 81
 
97
-            with(inOrder(mockSocket).verify(mockSocket)) {
98
-                sendLine("CAP LS 302")
99
-                sendLine("NICK :$NICK")
100
-                sendLine("USER $USER_NAME localhost $HOST :$REAL_NAME")
101
-            }
82
+        with(inOrder(mockSocket).verify(mockSocket, timeout(500))) {
83
+            sendLine("CAP LS 302")
84
+            sendLine("NICK :$NICK")
85
+            sendLine("USER $USER_NAME localhost $HOST :$REAL_NAME")
102 86
         }
103 87
     }
104 88
 
105 89
     @Test
106
-    fun `IrcClientImpl sends password first, when present`() {
107
-        runBlocking {
108
-            val client = IrcClientImpl(Server(HOST, PORT, password = PASSWORD), Profile(NICK, REAL_NAME, USER_NAME))
109
-            client.socketFactory = mockSocketFactory
110
-            readLineChannel.close()
111
-
112
-            client.connect()
113
-
114
-            with(inOrder(mockSocket).verify(mockSocket)) {
115
-                sendLine("CAP LS 302")
116
-                sendLine("PASS :$PASSWORD")
117
-                sendLine("NICK :$NICK")
118
-            }
90
+    fun `IrcClientImpl sends password first, when present`() = runBlocking {
91
+        val client = IrcClientImpl(Server(HOST, PORT, password = PASSWORD), Profile(NICK, REAL_NAME, USER_NAME))
92
+        client.socketFactory = mockSocketFactory
93
+        client.connect()
94
+
95
+        with(inOrder(mockSocket).verify(mockSocket, timeout(500))) {
96
+            sendLine("CAP LS 302")
97
+            sendLine("PASS :$PASSWORD")
98
+            sendLine("NICK :$NICK")
119 99
         }
120 100
     }
121 101
 
122 102
     @Test
123 103
     fun `IrcClientImpl sends events to provided event handler`() {
124
-        runBlocking {
125
-            val client = IrcClientImpl(Server(HOST, PORT, password = PASSWORD), Profile(NICK, REAL_NAME, USER_NAME))
126
-            client.socketFactory = mockSocketFactory
127
-            client.onEvent(mockEventHandler)
104
+        val client = IrcClientImpl(Server(HOST, PORT, password = PASSWORD), Profile(NICK, REAL_NAME, USER_NAME))
105
+        client.socketFactory = mockSocketFactory
106
+        client.onEvent(mockEventHandler)
128 107
 
129
-            launch {
130
-                readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
131
-                readLineChannel.close()
132
-            }
108
+        GlobalScope.launch {
109
+            readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
110
+        }
133 111
 
134
-            client.connect()
112
+        client.connect()
135 113
 
136
-            verify(mockEventHandler).invoke(isA<ServerWelcome>())
137
-        }
114
+        verify(mockEventHandler, timeout(500)).invoke(isA<ServerWelcome>())
138 115
     }
139 116
 
140 117
     @Test
@@ -161,4 +138,41 @@ internal class IrcClientImplTest {
161 138
         assertFalse(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
162 139
     }
163 140
 
141
+    @Test
142
+    fun `IrcClientImpl join blocks when socket is open`() {
143
+        val client = IrcClientImpl(Server(HOST, PORT, password = PASSWORD), Profile(NICK, REAL_NAME, USER_NAME))
144
+        client.socketFactory = mockSocketFactory
145
+
146
+        GlobalScope.launch {
147
+            readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
148
+        }
149
+
150
+        client.connect()
151
+        runBlocking {
152
+            assertNull(withTimeoutOrNull(100L) {
153
+                client.join()
154
+                true
155
+            })
156
+        }
157
+    }
158
+
159
+    @Test
160
+    fun `IrcClientImpl join returns when socket is closed`() {
161
+        val client = IrcClientImpl(Server(HOST, PORT, password = PASSWORD), Profile(NICK, REAL_NAME, USER_NAME))
162
+        client.socketFactory = mockSocketFactory
163
+
164
+        GlobalScope.launch {
165
+            readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
166
+            readLineChannel.close()
167
+        }
168
+
169
+        client.connect()
170
+        runBlocking {
171
+            assertEquals(true, withTimeoutOrNull(500L) {
172
+                client.join()
173
+                true
174
+            })
175
+        }
176
+    }
177
+
164 178
 }

Loading…
Cancel
Save