Bläddra i källkod

Rework coroutines to make more sense.

tags/v0.2.0
Chris Smith 5 år sedan
förälder
incheckning
3686c5aa2d

+ 1
- 0
CHANGELOG Visa fil

2
 
2
 
3
   * Added support for connecting over TLS
3
   * Added support for connecting over TLS
4
   * Simplified how event handlers are registered
4
   * Simplified how event handlers are registered
5
+  * Improved use of coroutines so users don't have to worry about them
5
   * (Internal) Upgraded to Gradle 5.1.1
6
   * (Internal) Upgraded to Gradle 5.1.1

+ 25
- 17
src/main/kotlin/com/dmdirc/ktirc/IrcClient.kt Visa fil

4
 import com.dmdirc.ktirc.io.*
4
 import com.dmdirc.ktirc.io.*
5
 import com.dmdirc.ktirc.messages.*
5
 import com.dmdirc.ktirc.messages.*
6
 import com.dmdirc.ktirc.model.*
6
 import com.dmdirc.ktirc.model.*
7
+import kotlinx.coroutines.*
7
 import kotlinx.coroutines.channels.map
8
 import kotlinx.coroutines.channels.map
8
-import kotlinx.coroutines.coroutineScope
9
-import kotlinx.coroutines.runBlocking
9
+import java.util.concurrent.atomic.AtomicBoolean
10
 import java.util.logging.Level
10
 import java.util.logging.Level
11
 import java.util.logging.LogManager
11
 import java.util.logging.LogManager
12
 
12
 
13
 
13
 
14
 interface IrcClient {
14
 interface IrcClient {
15
 
15
 
16
-    suspend fun send(message: String)
16
+    fun send(message: String)
17
 
17
 
18
     val serverState: ServerState
18
     val serverState: ServerState
19
     val channelState: ChannelStateMap
19
     val channelState: ChannelStateMap
44
     private val parser = MessageParser()
44
     private val parser = MessageParser()
45
     private var socket: LineBufferedSocket? = null
45
     private var socket: LineBufferedSocket? = null
46
 
46
 
47
-    // TODO: It would be cleaner if this didn't suspend but returned immediately
48
-    override suspend fun send(message: String) {
49
-        socket?.sendLine(message)
47
+    private val scope = CoroutineScope(Dispatchers.IO)
48
+    private val connecting = AtomicBoolean(false)
49
+
50
+    private var connectionJob: Job? = null
51
+
52
+    override fun send(message: String) {
53
+        scope.launch {
54
+            socket?.sendLine(message)
55
+        }
50
     }
56
     }
51
 
57
 
52
-    suspend fun connect() {
53
-        // TODO: Concurrency!
54
-        check(socket == null)
55
-        coroutineScope {
58
+    fun connect() {
59
+        check(!connecting.getAndSet(true))
60
+        connectionJob = scope.launch {
56
             with(socketFactory(server.host, server.port, server.tls)) {
61
             with(socketFactory(server.host, server.port, server.tls)) {
57
                 socket = this
62
                 socket = this
58
                 connect()
63
                 connect()
62
                 // TODO: Send correct host
67
                 // TODO: Send correct host
63
                 sendLine(userMessage(profile.userName, "localhost", server.host, profile.realName))
68
                 sendLine(userMessage(profile.userName, "localhost", server.host, profile.realName))
64
                 // TODO: This should be elsewhere
69
                 // TODO: This should be elsewhere
65
-                messageHandler.processMessages(this@IrcClientImpl, readLines(this@coroutineScope).map { parser.parse(it) })
70
+                messageHandler.processMessages(this@IrcClientImpl, readLines(scope).map { parser.parse(it) })
66
             }
71
             }
67
         }
72
         }
68
     }
73
     }
71
         socket?.disconnect()
76
         socket?.disconnect()
72
     }
77
     }
73
 
78
 
79
+    suspend fun join() {
80
+        connectionJob?.join()
81
+    }
82
+
74
     override fun onEvent(handler: (IrcEvent) -> Unit) {
83
     override fun onEvent(handler: (IrcEvent) -> Unit) {
75
         messageHandler.handlers.add(object : EventHandler {
84
         messageHandler.handlers.add(object : EventHandler {
76
-            override suspend fun processEvent(client: IrcClient, event: IrcEvent) {
85
+            override fun processEvent(client: IrcClient, event: IrcEvent) {
77
                 handler(event)
86
                 handler(event)
78
             }
87
             }
79
         })
88
         })
90
     runBlocking {
99
     runBlocking {
91
         val client = IrcClientImpl(Server("testnet.inspircd.org", 6667), Profile("KtIrc", "Kotlin!", "kotlin"))
100
         val client = IrcClientImpl(Server("testnet.inspircd.org", 6667), Profile("KtIrc", "Kotlin!", "kotlin"))
92
         client.onEvent { event ->
101
         client.onEvent { event ->
93
-            runBlocking {
94
-                when (event) {
95
-                    is ServerWelcome -> client.send(joinMessage("#ktirc"))
96
-                    is MessageReceived -> if (event.message == "!test") client.send(privmsgMessage(event.target, "Test successful!"))
97
-                }
102
+            when (event) {
103
+                is ServerWelcome -> client.send(joinMessage("#ktirc"))
104
+                is MessageReceived -> if (event.message == "!test") client.send(privmsgMessage(event.target, "Test successful!"))
98
             }
105
             }
99
         }
106
         }
100
         client.connect()
107
         client.connect()
108
+        client.join()
101
     }
109
     }
102
 }
110
 }

+ 3
- 3
src/main/kotlin/com/dmdirc/ktirc/events/CapabilitiesHandler.kt Visa fil

12
 
12
 
13
     private val log by logger()
13
     private val log by logger()
14
 
14
 
15
-    override suspend fun processEvent(client: IrcClient, event: IrcEvent) {
15
+    override fun processEvent(client: IrcClient, event: IrcEvent) {
16
         when (event) {
16
         when (event) {
17
             is ServerCapabilitiesReceived -> handleCapabilitiesReceived(client.serverState.capabilities, event.capabilities)
17
             is ServerCapabilitiesReceived -> handleCapabilitiesReceived(client.serverState.capabilities, event.capabilities)
18
             is ServerCapabilitiesFinished -> handleCapabilitiesFinished(client)
18
             is ServerCapabilitiesFinished -> handleCapabilitiesFinished(client)
24
         state.advertisedCapabilities.putAll(capabilities)
24
         state.advertisedCapabilities.putAll(capabilities)
25
     }
25
     }
26
 
26
 
27
-    private suspend fun handleCapabilitiesFinished(client: IrcClient) {
27
+    private fun handleCapabilitiesFinished(client: IrcClient) {
28
         // TODO: We probably need to split the outgoing REQ lines if there are lots of caps
28
         // TODO: We probably need to split the outgoing REQ lines if there are lots of caps
29
         // TODO: For caps with values we'll need to decide which value to use/whether to enable them/etc
29
         // TODO: For caps with values we'll need to decide which value to use/whether to enable them/etc
30
         with (client.serverState.capabilities) {
30
         with (client.serverState.capabilities) {
41
         }
41
         }
42
     }
42
     }
43
 
43
 
44
-    private suspend fun handleCapabilitiesAcknowledged(client: IrcClient, capabilities: Map<Capability, String>) {
44
+    private fun handleCapabilitiesAcknowledged(client: IrcClient, capabilities: Map<Capability, String>) {
45
         // TODO: Check if everything we wanted is enabled
45
         // TODO: Check if everything we wanted is enabled
46
         with (client.serverState.capabilities) {
46
         with (client.serverState.capabilities) {
47
             log.info { "Acknowledged capabilities: ${capabilities.keys.map { it.name }.toList()}" }
47
             log.info { "Acknowledged capabilities: ${capabilities.keys.map { it.name }.toList()}" }

+ 1
- 3
src/main/kotlin/com/dmdirc/ktirc/events/ChannelStateHandler.kt Visa fil

9
 
9
 
10
     private val log by logger()
10
     private val log by logger()
11
 
11
 
12
-    override suspend fun processEvent(client: IrcClient, event: IrcEvent) {
12
+    override fun processEvent(client: IrcClient, event: IrcEvent) {
13
         when (event) {
13
         when (event) {
14
             is ChannelJoined -> handleJoin(client, event)
14
             is ChannelJoined -> handleJoin(client, event)
15
             is ChannelParted -> handlePart(client, event)
15
             is ChannelParted -> handlePart(client, event)
64
         client.channelState.forEach { it.users -= event.user.nickname }
64
         client.channelState.forEach { it.users -= event.user.nickname }
65
     }
65
     }
66
 
66
 
67
-    private fun String.nickname(prefixLength: Int) = substring(prefixLength).substringBefore('!')
68
-
69
 }
67
 }

+ 1
- 1
src/main/kotlin/com/dmdirc/ktirc/events/EventHandler.kt Visa fil

5
 @FunctionalInterface
5
 @FunctionalInterface
6
 interface EventHandler {
6
 interface EventHandler {
7
 
7
 
8
-    suspend fun processEvent(client: IrcClient, event: IrcEvent)
8
+    fun processEvent(client: IrcClient, event: IrcEvent)
9
 
9
 
10
 }
10
 }
11
 
11
 

+ 1
- 1
src/main/kotlin/com/dmdirc/ktirc/events/PingHandler.kt Visa fil

5
 
5
 
6
 class PingHandler : EventHandler {
6
 class PingHandler : EventHandler {
7
 
7
 
8
-    override suspend fun processEvent(client: IrcClient, event: IrcEvent) {
8
+    override fun processEvent(client: IrcClient, event: IrcEvent) {
9
         when (event) {
9
         when (event) {
10
             is PingReceived -> client.send(pongMessage(event.nonce))
10
             is PingReceived -> client.send(pongMessage(event.nonce))
11
         }
11
         }

+ 1
- 1
src/main/kotlin/com/dmdirc/ktirc/events/ServerStateHandler.kt Visa fil

4
 
4
 
5
 class ServerStateHandler : EventHandler {
5
 class ServerStateHandler : EventHandler {
6
 
6
 
7
-    override suspend fun processEvent(client: IrcClient, event: IrcEvent) {
7
+    override fun processEvent(client: IrcClient, event: IrcEvent) {
8
         when (event) {
8
         when (event) {
9
             is ServerWelcome -> client.serverState.localNickname = event.localNick
9
             is ServerWelcome -> client.serverState.localNickname = event.localNick
10
             is ServerFeaturesUpdated -> client.serverState.features.setAll(event.serverFeatures)
10
             is ServerFeaturesUpdated -> client.serverState.features.setAll(event.serverFeatures)

+ 1
- 1
src/main/kotlin/com/dmdirc/ktirc/events/UserStateHandler.kt Visa fil

5
 
5
 
6
 class UserStateHandler : EventHandler {
6
 class UserStateHandler : EventHandler {
7
 
7
 
8
-    override suspend fun processEvent(client: IrcClient, event: IrcEvent) {
8
+    override fun processEvent(client: IrcClient, event: IrcEvent) {
9
         when (event) {
9
         when (event) {
10
             is ChannelJoined -> handleJoin(client.userState, event)
10
             is ChannelJoined -> handleJoin(client.userState, event)
11
             is ChannelParted -> handlePart(client, event)
11
             is ChannelParted -> handlePart(client, event)

+ 2
- 0
src/main/kotlin/com/dmdirc/ktirc/io/LineBufferedSocket.kt Visa fil

9
 import io.ktor.network.tls.tls
9
 import io.ktor.network.tls.tls
10
 import kotlinx.coroutines.CoroutineScope
10
 import kotlinx.coroutines.CoroutineScope
11
 import kotlinx.coroutines.Dispatchers
11
 import kotlinx.coroutines.Dispatchers
12
+import kotlinx.coroutines.ExperimentalCoroutinesApi
12
 import kotlinx.coroutines.GlobalScope
13
 import kotlinx.coroutines.GlobalScope
13
 import kotlinx.coroutines.channels.ReceiveChannel
14
 import kotlinx.coroutines.channels.ReceiveChannel
14
 import kotlinx.coroutines.channels.produce
15
 import kotlinx.coroutines.channels.produce
78
 
79
 
79
     override suspend fun sendLine(line: String) = sendLine(line.toByteArray())
80
     override suspend fun sendLine(line: String) = sendLine(line.toByteArray())
80
 
81
 
82
+    @ExperimentalCoroutinesApi
81
     override fun readLines(coroutineScope: CoroutineScope) = coroutineScope.produce {
83
     override fun readLines(coroutineScope: CoroutineScope) = coroutineScope.produce {
82
         val lineBuffer = ByteArray(4096)
84
         val lineBuffer = ByteArray(4096)
83
         var index = 0
85
         var index = 0

+ 76
- 62
src/test/kotlin/com/dmdirc/ktirc/IrcClientTest.kt Visa fil

6
 import com.dmdirc.ktirc.io.LineBufferedSocket
6
 import com.dmdirc.ktirc.io.LineBufferedSocket
7
 import com.dmdirc.ktirc.model.*
7
 import com.dmdirc.ktirc.model.*
8
 import com.nhaarman.mockitokotlin2.*
8
 import com.nhaarman.mockitokotlin2.*
9
+import kotlinx.coroutines.GlobalScope
9
 import kotlinx.coroutines.channels.Channel
10
 import kotlinx.coroutines.channels.Channel
10
 import kotlinx.coroutines.launch
11
 import kotlinx.coroutines.launch
11
 import kotlinx.coroutines.runBlocking
12
 import kotlinx.coroutines.runBlocking
13
+import kotlinx.coroutines.withTimeoutOrNull
12
 import org.junit.jupiter.api.Assertions.*
14
 import org.junit.jupiter.api.Assertions.*
13
 import org.junit.jupiter.api.BeforeEach
15
 import org.junit.jupiter.api.BeforeEach
14
 import org.junit.jupiter.api.Test
16
 import org.junit.jupiter.api.Test
44
 
46
 
45
     @Test
47
     @Test
46
     fun `IrcClientImpl uses socket factory to create a new socket on connect`() {
48
     fun `IrcClientImpl uses socket factory to create a new socket on connect`() {
47
-        runBlocking {
48
-            val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
49
-            client.socketFactory = mockSocketFactory
50
-            readLineChannel.close()
51
-
52
-            client.connect()
49
+        val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
50
+        client.socketFactory = mockSocketFactory
51
+        client.connect()
53
 
52
 
54
-            verify(mockSocketFactory).invoke(HOST, PORT, false)
55
-        }
53
+        verify(mockSocketFactory, timeout(500)).invoke(HOST, PORT, false)
56
     }
54
     }
57
 
55
 
58
     @Test
56
     @Test
59
     fun `IrcClientImpl uses socket factory to create a new tls on connect`() {
57
     fun `IrcClientImpl uses socket factory to create a new tls on connect`() {
60
-        runBlocking {
61
-            val client = IrcClientImpl(Server(HOST, PORT, true), Profile(NICK, REAL_NAME, USER_NAME))
62
-            client.socketFactory = mockSocketFactory
63
-            readLineChannel.close()
64
-
65
-            client.connect()
58
+        val client = IrcClientImpl(Server(HOST, PORT, true), Profile(NICK, REAL_NAME, USER_NAME))
59
+        client.socketFactory = mockSocketFactory
60
+        client.connect()
66
 
61
 
67
-            verify(mockSocketFactory).invoke(HOST, PORT, true)
68
-        }
62
+        verify(mockSocketFactory, timeout(500)).invoke(HOST, PORT, true)
69
     }
63
     }
70
 
64
 
71
     @Test
65
     @Test
72
     fun `IrcClientImpl throws if socket already exists`() {
66
     fun `IrcClientImpl throws if socket already exists`() {
73
-        runBlocking {
74
-            val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
75
-            client.socketFactory = mockSocketFactory
76
-            readLineChannel.close()
67
+        val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
68
+        client.socketFactory = mockSocketFactory
69
+        client.connect()
77
 
70
 
71
+        assertThrows<IllegalStateException> {
78
             client.connect()
72
             client.connect()
79
-
80
-            assertThrows<IllegalStateException> {
81
-                runBlocking {
82
-                    client.connect()
83
-                }
84
-            }
85
         }
73
         }
86
     }
74
     }
87
 
75
 
88
     @Test
76
     @Test
89
-    fun `IrcClientImpl sends basic connection strings`() {
90
-        runBlocking {
91
-            val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
92
-            client.socketFactory = mockSocketFactory
93
-            readLineChannel.close()
94
-
95
-            client.connect()
77
+    fun `IrcClientImpl sends basic connection strings`() = runBlocking {
78
+        val client = IrcClientImpl(Server(HOST, PORT), Profile(NICK, REAL_NAME, USER_NAME))
79
+        client.socketFactory = mockSocketFactory
80
+        client.connect()
96
 
81
 
97
-            with(inOrder(mockSocket).verify(mockSocket)) {
98
-                sendLine("CAP LS 302")
99
-                sendLine("NICK :$NICK")
100
-                sendLine("USER $USER_NAME localhost $HOST :$REAL_NAME")
101
-            }
82
+        with(inOrder(mockSocket).verify(mockSocket, timeout(500))) {
83
+            sendLine("CAP LS 302")
84
+            sendLine("NICK :$NICK")
85
+            sendLine("USER $USER_NAME localhost $HOST :$REAL_NAME")
102
         }
86
         }
103
     }
87
     }
104
 
88
 
105
     @Test
89
     @Test
106
-    fun `IrcClientImpl sends password first, when present`() {
107
-        runBlocking {
108
-            val client = IrcClientImpl(Server(HOST, PORT, password = PASSWORD), Profile(NICK, REAL_NAME, USER_NAME))
109
-            client.socketFactory = mockSocketFactory
110
-            readLineChannel.close()
111
-
112
-            client.connect()
113
-
114
-            with(inOrder(mockSocket).verify(mockSocket)) {
115
-                sendLine("CAP LS 302")
116
-                sendLine("PASS :$PASSWORD")
117
-                sendLine("NICK :$NICK")
118
-            }
90
+    fun `IrcClientImpl sends password first, when present`() = runBlocking {
91
+        val client = IrcClientImpl(Server(HOST, PORT, password = PASSWORD), Profile(NICK, REAL_NAME, USER_NAME))
92
+        client.socketFactory = mockSocketFactory
93
+        client.connect()
94
+
95
+        with(inOrder(mockSocket).verify(mockSocket, timeout(500))) {
96
+            sendLine("CAP LS 302")
97
+            sendLine("PASS :$PASSWORD")
98
+            sendLine("NICK :$NICK")
119
         }
99
         }
120
     }
100
     }
121
 
101
 
122
     @Test
102
     @Test
123
     fun `IrcClientImpl sends events to provided event handler`() {
103
     fun `IrcClientImpl sends events to provided event handler`() {
124
-        runBlocking {
125
-            val client = IrcClientImpl(Server(HOST, PORT, password = PASSWORD), Profile(NICK, REAL_NAME, USER_NAME))
126
-            client.socketFactory = mockSocketFactory
127
-            client.onEvent(mockEventHandler)
104
+        val client = IrcClientImpl(Server(HOST, PORT, password = PASSWORD), Profile(NICK, REAL_NAME, USER_NAME))
105
+        client.socketFactory = mockSocketFactory
106
+        client.onEvent(mockEventHandler)
128
 
107
 
129
-            launch {
130
-                readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
131
-                readLineChannel.close()
132
-            }
108
+        GlobalScope.launch {
109
+            readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
110
+        }
133
 
111
 
134
-            client.connect()
112
+        client.connect()
135
 
113
 
136
-            verify(mockEventHandler).invoke(isA<ServerWelcome>())
137
-        }
114
+        verify(mockEventHandler, timeout(500)).invoke(isA<ServerWelcome>())
138
     }
115
     }
139
 
116
 
140
     @Test
117
     @Test
161
         assertFalse(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
138
         assertFalse(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
162
     }
139
     }
163
 
140
 
141
+    @Test
142
+    fun `IrcClientImpl join blocks when socket is open`() {
143
+        val client = IrcClientImpl(Server(HOST, PORT, password = PASSWORD), Profile(NICK, REAL_NAME, USER_NAME))
144
+        client.socketFactory = mockSocketFactory
145
+
146
+        GlobalScope.launch {
147
+            readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
148
+        }
149
+
150
+        client.connect()
151
+        runBlocking {
152
+            assertNull(withTimeoutOrNull(100L) {
153
+                client.join()
154
+                true
155
+            })
156
+        }
157
+    }
158
+
159
+    @Test
160
+    fun `IrcClientImpl join returns when socket is closed`() {
161
+        val client = IrcClientImpl(Server(HOST, PORT, password = PASSWORD), Profile(NICK, REAL_NAME, USER_NAME))
162
+        client.socketFactory = mockSocketFactory
163
+
164
+        GlobalScope.launch {
165
+            readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
166
+            readLineChannel.close()
167
+        }
168
+
169
+        client.connect()
170
+        runBlocking {
171
+            assertEquals(true, withTimeoutOrNull(500L) {
172
+                client.join()
173
+                true
174
+            })
175
+        }
176
+    }
177
+
164
 }
178
 }

Laddar…
Avbryt
Spara