You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

IrcClientImplTest.kt 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. package com.dmdirc.ktirc
  2. import com.dmdirc.ktirc.events.*
  3. import com.dmdirc.ktirc.io.CaseMapping
  4. import com.dmdirc.ktirc.io.LineBufferedSocket
  5. import com.dmdirc.ktirc.messages.tagMap
  6. import com.dmdirc.ktirc.model.*
  7. import com.dmdirc.ktirc.util.currentTimeProvider
  8. import com.dmdirc.ktirc.util.generateLabel
  9. import io.mockk.*
  10. import kotlinx.coroutines.*
  11. import kotlinx.coroutines.channels.Channel
  12. import kotlinx.coroutines.channels.filter
  13. import kotlinx.coroutines.channels.map
  14. import kotlinx.coroutines.sync.Mutex
  15. import org.junit.jupiter.api.Assertions.*
  16. import org.junit.jupiter.api.BeforeEach
  17. import org.junit.jupiter.api.Test
  18. import org.junit.jupiter.api.assertThrows
  19. import java.nio.channels.UnresolvedAddressException
  20. import java.security.cert.CertificateException
  21. import java.util.concurrent.atomic.AtomicReference
  22. @ExperimentalCoroutinesApi
  23. internal class IrcClientImplTest {
  24. companion object {
  25. private const val HOST = "thegibson.com"
  26. private const val PORT = 12345
  27. private const val NICK = "AcidBurn"
  28. private const val REAL_NAME = "Kate Libby"
  29. private const val USER_NAME = "acidb"
  30. private const val PASSWORD = "HackThePlanet"
  31. }
  32. private val readLineChannel = Channel<ByteArray>(Channel.UNLIMITED)
  33. private val sendLineChannel = Channel<ByteArray>(Channel.UNLIMITED)
  34. private val mockSocket = mockk<LineBufferedSocket> {
  35. every { receiveChannel } returns readLineChannel
  36. every { sendChannel } returns sendLineChannel
  37. }
  38. private val mockSocketFactory = mockk<(CoroutineScope, String, Int, Boolean) -> LineBufferedSocket> {
  39. every { this@mockk.invoke(any(), eq(HOST), eq(PORT), any()) } returns mockSocket
  40. }
  41. private val mockEventHandler = mockk<(IrcEvent) -> Unit> {
  42. every { this@mockk.invoke(any()) } just Runs
  43. }
  44. private val profileConfig = ProfileConfig().apply {
  45. nickname = NICK
  46. realName = REAL_NAME
  47. username = USER_NAME
  48. }
  49. private val serverConfig = ServerConfig().apply {
  50. host = HOST
  51. port = PORT
  52. }
  53. private val normalConfig = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig(), null)
  54. @BeforeEach
  55. fun setUp() {
  56. currentTimeProvider = { TestConstants.time }
  57. }
  58. @Test
  59. fun `uses socket factory to create a new socket on connect`() {
  60. val client = IrcClientImpl(normalConfig)
  61. client.socketFactory = mockSocketFactory
  62. client.connect()
  63. verify(timeout = 500) { mockSocketFactory(client, HOST, PORT, false) }
  64. }
  65. @Test
  66. fun `uses socket factory to create a new tls on connect`() {
  67. val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
  68. host = HOST
  69. port = PORT
  70. useTls = true
  71. }, profileConfig, BehaviourConfig(), null))
  72. client.socketFactory = mockSocketFactory
  73. client.connect()
  74. verify(timeout = 500) { mockSocketFactory(client, HOST, PORT, true) }
  75. }
  76. @Test
  77. fun `throws if socket already exists`() {
  78. val client = IrcClientImpl(normalConfig)
  79. client.socketFactory = mockSocketFactory
  80. client.connect()
  81. assertThrows<IllegalStateException> {
  82. client.connect()
  83. }
  84. }
  85. @Test
  86. fun `emits connection events with local time`() = runBlocking {
  87. currentTimeProvider = { TestConstants.time }
  88. val connectingSlot = slot<ServerConnecting>()
  89. val connectedSlot = slot<ServerConnected>()
  90. every { mockEventHandler.invoke(capture(connectingSlot)) } just Runs
  91. every { mockEventHandler.invoke(capture(connectedSlot)) } just Runs
  92. val client = IrcClientImpl(normalConfig)
  93. client.socketFactory = mockSocketFactory
  94. client.onEvent(mockEventHandler)
  95. client.connect()
  96. verify(timeout = 500) {
  97. mockEventHandler(ofType<ServerConnecting>())
  98. mockEventHandler(ofType<ServerConnected>())
  99. }
  100. assertEquals(TestConstants.time, connectingSlot.captured.metadata.time)
  101. assertEquals(TestConstants.time, connectedSlot.captured.metadata.time)
  102. }
  103. @Test
  104. fun `sends basic connection strings`() = runBlocking {
  105. val client = IrcClientImpl(normalConfig)
  106. client.socketFactory = mockSocketFactory
  107. client.connect()
  108. assertEquals("CAP LS 302", String(sendLineChannel.receive()))
  109. assertEquals("NICK $NICK", String(sendLineChannel.receive()))
  110. assertEquals("USER $USER_NAME 0 * :$REAL_NAME", String(sendLineChannel.receive()))
  111. }
  112. @Test
  113. fun `sends password first, when present`() = runBlocking {
  114. val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
  115. host = HOST
  116. port = PORT
  117. password = PASSWORD
  118. }, profileConfig, BehaviourConfig(), null))
  119. client.socketFactory = mockSocketFactory
  120. client.connect()
  121. assertEquals("CAP LS 302", String(sendLineChannel.receive()))
  122. assertEquals("PASS $PASSWORD", String(sendLineChannel.receive()))
  123. }
  124. @Test
  125. fun `sends events to provided event handler`() {
  126. val client = IrcClientImpl(normalConfig)
  127. client.socketFactory = mockSocketFactory
  128. client.onEvent(mockEventHandler)
  129. GlobalScope.launch {
  130. readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
  131. }
  132. client.connect()
  133. verify(timeout = 500) {
  134. mockEventHandler(ofType<ServerWelcome>())
  135. }
  136. }
  137. @Test
  138. fun `gets case mapping from server features`() {
  139. val client = IrcClientImpl(normalConfig)
  140. client.serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.RfcStrict
  141. assertEquals(CaseMapping.RfcStrict, client.caseMapping)
  142. }
  143. @Test
  144. fun `indicates if user is local user or not`() {
  145. val client = IrcClientImpl(normalConfig)
  146. client.serverState.localNickname = "[acidBurn]"
  147. assertTrue(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
  148. assertFalse(client.isLocalUser(User("acid-Burn", "libby", "root.localhost")))
  149. }
  150. @Test
  151. fun `indicates if nickname is local user or not`() {
  152. val client = IrcClientImpl(normalConfig)
  153. client.serverState.localNickname = "[acidBurn]"
  154. assertTrue(client.isLocalUser("{acidBurn}"))
  155. assertFalse(client.isLocalUser("acid-Burn"))
  156. }
  157. @Test
  158. fun `uses current case mapping to check local user`() {
  159. val client = IrcClientImpl(normalConfig)
  160. client.serverState.localNickname = "[acidBurn]"
  161. client.serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.Ascii
  162. assertFalse(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
  163. }
  164. @Test
  165. @SuppressWarnings("deprecation")
  166. fun `sends text to socket`() = runBlocking {
  167. val client = IrcClientImpl(normalConfig)
  168. client.socketFactory = mockSocketFactory
  169. client.connect()
  170. client.send("testing 123")
  171. assertLineReceived("testing 123")
  172. }
  173. @Test
  174. fun `sends structured text to socket`() = runBlocking {
  175. val client = IrcClientImpl(normalConfig)
  176. client.socketFactory = mockSocketFactory
  177. client.connect()
  178. client.send("testing", "123", "456")
  179. assertLineReceived("testing 123 456")
  180. }
  181. @Test
  182. fun `echoes message event when behaviour is set and cap is unsupported`() = runBlocking {
  183. val config = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig().apply { alwaysEchoMessages = true }, null)
  184. val client = IrcClientImpl(config)
  185. client.socketFactory = mockSocketFactory
  186. val slot = slot<MessageReceived>()
  187. val mockkEventHandler = mockk<(IrcEvent) -> Unit>(relaxed = true)
  188. every { mockkEventHandler(capture(slot)) } just Runs
  189. client.onEvent(mockkEventHandler)
  190. client.connect()
  191. client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")
  192. assertTrue(slot.isCaptured)
  193. val event = slot.captured
  194. assertEquals("#thegibson", event.target)
  195. assertEquals("Mess with the best, die like the rest", event.message)
  196. assertEquals(NICK, event.user.nickname)
  197. assertEquals(TestConstants.time, event.metadata.time)
  198. }
  199. @Test
  200. fun `does not echo message event when behaviour is set and cap is supported`() = runBlocking {
  201. val config = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig().apply { alwaysEchoMessages = true }, null)
  202. val client = IrcClientImpl(config)
  203. client.socketFactory = mockSocketFactory
  204. client.serverState.capabilities.enabledCapabilities[Capability.EchoMessages] = ""
  205. client.connect()
  206. client.onEvent(mockEventHandler)
  207. client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")
  208. verify(inverse = true) {
  209. mockEventHandler(ofType<MessageReceived>())
  210. }
  211. }
  212. @Test
  213. fun `does not echo message event when behaviour is unset`() = runBlocking {
  214. val config = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig().apply { alwaysEchoMessages = false }, null)
  215. val client = IrcClientImpl(config)
  216. client.socketFactory = mockSocketFactory
  217. client.connect()
  218. client.onEvent(mockEventHandler)
  219. client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")
  220. verify(inverse = true) {
  221. mockEventHandler(ofType<MessageReceived>())
  222. } }
  223. @Test
  224. fun `sends structured text to socket with tags`() = runBlocking {
  225. val client = IrcClientImpl(normalConfig)
  226. client.socketFactory = mockSocketFactory
  227. client.connect()
  228. client.send(tagMap(MessageTag.AccountName to "acidB"), "testing", "123", "456")
  229. assertLineReceived("@account=acidB testing 123 456")
  230. }
  231. @Test
  232. fun `asynchronously sends text to socket without label if cap is missing`() = runBlocking {
  233. val client = IrcClientImpl(normalConfig)
  234. client.socketFactory = mockSocketFactory
  235. client.connect()
  236. client.sendAsync(tagMap(), "testing", arrayOf("123")) { false }
  237. assertLineReceived("testing 123")
  238. }
  239. @Test
  240. fun `asynchronously sends text to socket with added tags and label`() = runBlocking {
  241. generateLabel = { "abc123" }
  242. val client = IrcClientImpl(normalConfig)
  243. client.socketFactory = mockSocketFactory
  244. client.serverState.capabilities.enabledCapabilities[Capability.LabeledResponse] = ""
  245. client.connect()
  246. client.sendAsync(tagMap(), "testing", arrayOf("123")) { false }
  247. assertLineReceived("@draft/label=abc123 testing 123")
  248. }
  249. @Test
  250. fun `asynchronously sends tagged text to socket with label`() = runBlocking {
  251. generateLabel = { "abc123" }
  252. val client = IrcClientImpl(normalConfig)
  253. client.socketFactory = mockSocketFactory
  254. client.serverState.capabilities.enabledCapabilities[Capability.LabeledResponse] = ""
  255. client.connect()
  256. client.sendAsync(tagMap(MessageTag.AccountName to "x"), "testing", arrayOf("123")) { false }
  257. assertLineReceived("@account=x;draft/label=abc123 testing 123")
  258. }
  259. @Test
  260. fun `disconnects the socket`() = runBlocking {
  261. val client = IrcClientImpl(normalConfig)
  262. client.socketFactory = mockSocketFactory
  263. client.connect()
  264. client.disconnect()
  265. verify(timeout = 500) {
  266. mockSocket.disconnect()
  267. }
  268. }
  269. @Test
  270. @ObsoleteCoroutinesApi
  271. fun `sends messages in order`() = runBlocking {
  272. val client = IrcClientImpl(normalConfig)
  273. client.socketFactory = mockSocketFactory
  274. client.connect()
  275. (0..100).forEach { client.send("TEST", "$it") }
  276. assertEquals(100, withTimeoutOrNull(500) {
  277. var next = 0
  278. for (line in sendLineChannel.map { String(it) }.filter { it.startsWith("TEST ") }) {
  279. assertEquals("TEST $next", line)
  280. if (++next == 100) {
  281. break
  282. }
  283. }
  284. next
  285. })
  286. }
  287. @Test
  288. fun `defaults local nickname to profile`() {
  289. val client = IrcClientImpl(normalConfig)
  290. assertEquals(NICK, client.serverState.localNickname)
  291. }
  292. @Test
  293. fun `defaults server name to host name`() {
  294. val client = IrcClientImpl(normalConfig)
  295. assertEquals(HOST, client.serverState.serverName)
  296. }
  297. @Test
  298. fun `exposes behaviour config`() {
  299. val client = IrcClientImpl(IrcClientConfig(
  300. ServerConfig().apply { host = HOST },
  301. profileConfig,
  302. BehaviourConfig().apply { requestModesOnJoin = true },
  303. null))
  304. assertTrue(client.behaviour.requestModesOnJoin)
  305. }
  306. @Test
  307. fun `reset clears all state`() {
  308. with(IrcClientImpl(normalConfig)) {
  309. userState += User("acidBurn")
  310. channelState += ChannelState("#thegibson") { CaseMapping.Rfc }
  311. serverState.serverName = "root.$HOST"
  312. reset()
  313. assertEquals(0, userState.count())
  314. assertEquals(0, channelState.count())
  315. assertEquals(HOST, serverState.serverName)
  316. }
  317. }
  318. @Test
  319. fun `sends connect error when host is unresolvable`() = runBlocking {
  320. every { mockSocket.connect() } throws UnresolvedAddressException()
  321. with(IrcClientImpl(normalConfig)) {
  322. socketFactory = mockSocketFactory
  323. withTimeout(500) {
  324. launch {
  325. delay(50)
  326. connect()
  327. }
  328. val event = waitForEvent<ServerConnectionError>()
  329. assertEquals(ConnectionError.UnresolvableAddress, event.error)
  330. }
  331. }
  332. }
  333. @Test
  334. fun `sends connect error when tls certificate is bad`() = runBlocking {
  335. every { mockSocket.connect() } throws CertificateException("Boooo")
  336. with(IrcClientImpl(normalConfig)) {
  337. socketFactory = mockSocketFactory
  338. withTimeout(500) {
  339. launch {
  340. delay(50)
  341. connect()
  342. }
  343. val event = waitForEvent<ServerConnectionError>()
  344. assertEquals(ConnectionError.BadTlsCertificate, event.error)
  345. assertEquals("Boooo", event.details)
  346. }
  347. }
  348. }
  349. @Test
  350. fun `identifies channels that have a prefix in the chantypes feature`() {
  351. with(IrcClientImpl(normalConfig)) {
  352. serverState.features[ServerFeature.ChannelTypes] = "&~"
  353. assertTrue(isChannel("&dumpsterdiving"))
  354. assertTrue(isChannel("~hacktheplanet"))
  355. assertFalse(isChannel("#root"))
  356. assertFalse(isChannel("acidBurn"))
  357. assertFalse(isChannel(""))
  358. assertFalse(isChannel("acidBurn#~"))
  359. }
  360. }
  361. private suspend inline fun <reified T : IrcEvent> IrcClient.waitForEvent(): T {
  362. val mutex = Mutex(true)
  363. val value = AtomicReference<T>()
  364. onEvent {
  365. if (it is T) {
  366. value.set(it)
  367. mutex.unlock()
  368. }
  369. }
  370. mutex.lock()
  371. return value.get()
  372. }
  373. private suspend fun assertLineReceived(expected: String) {
  374. assertEquals(true, withTimeoutOrNull(500) {
  375. for (line in sendLineChannel.map { String(it) }) {
  376. println(line)
  377. if (line == expected) {
  378. return@withTimeoutOrNull true
  379. }
  380. }
  381. false
  382. }) { "Expected to receive $expected" }
  383. }
  384. }