You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

IrcClientImplTest.kt 21KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615
  1. package com.dmdirc.ktirc
  2. import com.dmdirc.ktirc.events.*
  3. import com.dmdirc.ktirc.io.CaseMapping
  4. import com.dmdirc.ktirc.io.LineBufferedSocket
  5. import com.dmdirc.ktirc.messages.tagMap
  6. import com.dmdirc.ktirc.model.*
  7. import com.dmdirc.ktirc.util.RemoveIn
  8. import com.dmdirc.ktirc.util.currentTimeProvider
  9. import com.dmdirc.ktirc.util.generateLabel
  10. import io.mockk.*
  11. import kotlinx.coroutines.*
  12. import kotlinx.coroutines.channels.Channel
  13. import kotlinx.coroutines.channels.filter
  14. import kotlinx.coroutines.channels.map
  15. import kotlinx.coroutines.sync.Mutex
  16. import org.junit.jupiter.api.Assertions.*
  17. import org.junit.jupiter.api.BeforeEach
  18. import org.junit.jupiter.api.Test
  19. import org.junit.jupiter.api.assertThrows
  20. import java.net.UnknownHostException
  21. import java.nio.channels.UnresolvedAddressException
  22. import java.security.cert.CertificateException
  23. import java.util.concurrent.atomic.AtomicReference
  24. @ExperimentalCoroutinesApi
  25. internal class IrcClientImplTest {
  26. companion object {
  27. private const val HOST = "thegibson.com"
  28. private const val HOST2 = "irc.thegibson.com"
  29. private const val IP = "127.0.13.37"
  30. private const val PORT = 12345
  31. private const val NICK = "AcidBurn"
  32. private const val REAL_NAME = "Kate Libby"
  33. private const val USER_NAME = "acidb"
  34. private const val PASSWORD = "HackThePlanet"
  35. }
  36. private val readLineChannel = Channel<ByteArray>(Channel.UNLIMITED)
  37. private val sendLineChannel = Channel<ByteArray>(Channel.UNLIMITED)
  38. private val mockSocket = mockk<LineBufferedSocket> {
  39. every { receiveChannel } returns readLineChannel
  40. every { sendChannel } returns sendLineChannel
  41. }
  42. private val mockSocketFactory = mockk<(CoroutineScope, String, String, Int, Boolean) -> LineBufferedSocket> {
  43. every { this@mockk.invoke(any(), eq(HOST), eq(IP), eq(PORT), any()) } returns mockSocket
  44. every { this@mockk.invoke(any(), eq(HOST2), any(), eq(PORT), any()) } returns mockSocket
  45. }
  46. private val mockResolver = mockk<(String) -> Collection<ResolveResult>> {
  47. every { this@mockk.invoke(HOST) } returns listOf(ResolveResult(IP, false))
  48. }
  49. private val mockEventHandler = mockk<(IrcEvent) -> Unit> {
  50. every { this@mockk.invoke(any()) } just Runs
  51. }
  52. private val profileConfig = ProfileConfig().apply {
  53. nickname = NICK
  54. realName = REAL_NAME
  55. username = USER_NAME
  56. }
  57. private val serverConfig = ServerConfig().apply {
  58. host = HOST
  59. port = PORT
  60. useTls = false
  61. }
  62. private val normalConfig = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig(), null)
  63. @BeforeEach
  64. fun setUp() {
  65. currentTimeProvider = { TestConstants.time }
  66. }
  67. @Test
  68. fun `uses socket factory to create a new socket on connect`() {
  69. val client = IrcClientImpl(normalConfig)
  70. client.socketFactory = mockSocketFactory
  71. client.resolver = mockResolver
  72. client.connect()
  73. verify(timeout = 500) { mockSocketFactory(client, HOST, IP, PORT, false) }
  74. }
  75. @Test
  76. fun `uses socket factory to create a new tls on connect`() {
  77. val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
  78. host = HOST
  79. port = PORT
  80. useTls = true
  81. }, profileConfig, BehaviourConfig(), null))
  82. client.socketFactory = mockSocketFactory
  83. client.resolver = mockResolver
  84. client.connect()
  85. verify(timeout = 500) { mockSocketFactory(client, HOST, IP, PORT, true) }
  86. }
  87. @Test
  88. fun `prefers ipv6 addresses if behaviour is enabled`() {
  89. val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
  90. host = HOST2
  91. port = PORT
  92. }, profileConfig, BehaviourConfig().apply { preferIPv6 = true }, null))
  93. every { mockResolver(HOST2) } returns listOf(
  94. ResolveResult(IP, false),
  95. ResolveResult("::13:37", true),
  96. ResolveResult("0.0.0.0", false)
  97. )
  98. client.socketFactory = mockSocketFactory
  99. client.resolver = mockResolver
  100. client.connect()
  101. verify(timeout = 500) { mockSocketFactory(client, HOST2, "::13:37", PORT, true) }
  102. }
  103. @Test
  104. fun `falls back to ipv4 if no ipv6 addresses are available`() {
  105. val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
  106. host = HOST2
  107. port = PORT
  108. }, profileConfig, BehaviourConfig().apply { preferIPv6 = true }, null))
  109. every { mockResolver(HOST2) } returns listOf(
  110. ResolveResult("0.0.0.0", false)
  111. )
  112. client.socketFactory = mockSocketFactory
  113. client.resolver = mockResolver
  114. client.connect()
  115. verify(timeout = 500) { mockSocketFactory(client, HOST2, "0.0.0.0", PORT, true) }
  116. }
  117. @Test
  118. fun `prefers ipv4 addresses if ipv6 behaviour is disabled`() {
  119. val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
  120. host = HOST2
  121. port = PORT
  122. }, profileConfig, BehaviourConfig().apply { preferIPv6 = false }, null))
  123. every { mockResolver(HOST2) } returns listOf(
  124. ResolveResult("::13:37", true),
  125. ResolveResult("::313:37", true),
  126. ResolveResult("0.0.0.0", false)
  127. )
  128. client.socketFactory = mockSocketFactory
  129. client.resolver = mockResolver
  130. client.connect()
  131. verify(timeout = 500) { mockSocketFactory(client, HOST2, "0.0.0.0", PORT, true) }
  132. }
  133. @Test
  134. fun `falls back to ipv6 if no ipv4 addresses available`() {
  135. val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
  136. host = HOST2
  137. port = PORT
  138. }, profileConfig, BehaviourConfig().apply { preferIPv6 = false }, null))
  139. every { mockResolver(HOST2) } returns listOf(
  140. ResolveResult("::13:37", true)
  141. )
  142. client.socketFactory = mockSocketFactory
  143. client.resolver = mockResolver
  144. client.connect()
  145. verify(timeout = 500) { mockSocketFactory(client, HOST2, "::13:37", PORT, true) }
  146. }
  147. @Test
  148. fun `raises error if dns fails`() {
  149. val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
  150. host = HOST2
  151. }, profileConfig, BehaviourConfig().apply { preferIPv6 = true }, null))
  152. every { mockResolver(HOST2) } throws UnknownHostException("oops")
  153. client.socketFactory = mockSocketFactory
  154. client.resolver = mockResolver
  155. client.onEvent(mockEventHandler)
  156. client.connect()
  157. verify(timeout = 500) {
  158. mockEventHandler(match { it is ServerConnectionError && it.error == ConnectionError.UnresolvableAddress })
  159. }
  160. }
  161. @Test
  162. fun `throws if socket already exists`() {
  163. val client = IrcClientImpl(normalConfig)
  164. client.socketFactory = mockSocketFactory
  165. client.resolver = mockResolver
  166. client.connect()
  167. assertThrows<IllegalStateException> {
  168. client.connect()
  169. }
  170. }
  171. @Test
  172. fun `emits connection events with local time`() = runBlocking {
  173. currentTimeProvider = { TestConstants.time }
  174. val connectingSlot = slot<ServerConnecting>()
  175. val connectedSlot = slot<ServerConnected>()
  176. every { mockEventHandler.invoke(capture(connectingSlot)) } just Runs
  177. every { mockEventHandler.invoke(capture(connectedSlot)) } just Runs
  178. val client = IrcClientImpl(normalConfig)
  179. client.socketFactory = mockSocketFactory
  180. client.resolver = mockResolver
  181. client.onEvent(mockEventHandler)
  182. client.connect()
  183. verify(timeout = 500) {
  184. mockEventHandler(ofType<ServerConnecting>())
  185. mockEventHandler(ofType<ServerConnected>())
  186. }
  187. assertEquals(TestConstants.time, connectingSlot.captured.metadata.time)
  188. assertEquals(TestConstants.time, connectedSlot.captured.metadata.time)
  189. }
  190. @Test
  191. fun `sends basic connection strings`() = runBlocking {
  192. val client = IrcClientImpl(normalConfig)
  193. client.socketFactory = mockSocketFactory
  194. client.resolver = mockResolver
  195. client.connect()
  196. assertEquals("CAP LS 302", String(sendLineChannel.receive()))
  197. assertEquals("NICK $NICK", String(sendLineChannel.receive()))
  198. assertEquals("USER $USER_NAME 0 * :$REAL_NAME", String(sendLineChannel.receive()))
  199. }
  200. @Test
  201. fun `sends password first, when present`() = runBlocking {
  202. val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
  203. host = HOST
  204. port = PORT
  205. password = PASSWORD
  206. }, profileConfig, BehaviourConfig(), null))
  207. client.socketFactory = mockSocketFactory
  208. client.resolver = mockResolver
  209. client.connect()
  210. assertEquals("CAP LS 302", String(sendLineChannel.receive()))
  211. assertEquals("PASS $PASSWORD", String(sendLineChannel.receive()))
  212. }
  213. @Test
  214. fun `sends events to provided event handler`() {
  215. val client = IrcClientImpl(normalConfig)
  216. client.socketFactory = mockSocketFactory
  217. client.resolver = mockResolver
  218. client.onEvent(mockEventHandler)
  219. GlobalScope.launch {
  220. readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
  221. }
  222. client.connect()
  223. verify(timeout = 500) {
  224. mockEventHandler(ofType<ServerWelcome>())
  225. }
  226. }
  227. @Test
  228. fun `gets case mapping from server features`() {
  229. val client = IrcClientImpl(normalConfig)
  230. client.serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.RfcStrict
  231. assertEquals(CaseMapping.RfcStrict, client.caseMapping)
  232. }
  233. @Test
  234. fun `indicates if user is local user or not`() {
  235. val client = IrcClientImpl(normalConfig)
  236. client.localUser.nickname = "[acidBurn]"
  237. assertTrue(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
  238. assertFalse(client.isLocalUser(User("acid-Burn", "libby", "root.localhost")))
  239. }
  240. @Test
  241. fun `indicates if nickname is local user or not`() {
  242. val client = IrcClientImpl(normalConfig)
  243. client.localUser.nickname = "[acidBurn]"
  244. assertTrue(client.isLocalUser("{acidBurn}"))
  245. assertFalse(client.isLocalUser("acid-Burn"))
  246. }
  247. @Test
  248. fun `uses current case mapping to check local user`() {
  249. val client = IrcClientImpl(normalConfig)
  250. client.localUser.nickname = "[acidBurn]"
  251. client.serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.Ascii
  252. assertFalse(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
  253. }
  254. @Test
  255. @Deprecated("Tests deprecated method")
  256. @RemoveIn("2.0.0")
  257. fun `sends text to socket`() = runBlocking {
  258. val client = IrcClientImpl(normalConfig)
  259. client.socketFactory = mockSocketFactory
  260. client.resolver = mockResolver
  261. client.connect()
  262. client.send("testing 123")
  263. assertLineReceived("testing 123")
  264. }
  265. @Test
  266. fun `sends structured text to socket`() = runBlocking {
  267. val client = IrcClientImpl(normalConfig)
  268. client.socketFactory = mockSocketFactory
  269. client.resolver = mockResolver
  270. client.connect()
  271. client.send("testing", "123", "456")
  272. assertLineReceived("testing 123 456")
  273. }
  274. @Test
  275. fun `echoes message event when behaviour is set and cap is unsupported`() = runBlocking {
  276. val config = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig().apply { alwaysEchoMessages = true }, null)
  277. val client = IrcClientImpl(config)
  278. client.socketFactory = mockSocketFactory
  279. client.resolver = mockResolver
  280. val slot = slot<MessageReceived>()
  281. val mockkEventHandler = mockk<(IrcEvent) -> Unit>(relaxed = true)
  282. every { mockkEventHandler(capture(slot)) } just Runs
  283. client.onEvent(mockkEventHandler)
  284. client.connect()
  285. client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")
  286. assertTrue(slot.isCaptured)
  287. val event = slot.captured
  288. assertEquals("#thegibson", event.target)
  289. assertEquals("Mess with the best, die like the rest", event.message)
  290. assertEquals(NICK, event.user.nickname)
  291. assertEquals(TestConstants.time, event.metadata.time)
  292. }
  293. @Test
  294. fun `does not echo message event when behaviour is set and cap is supported`() = runBlocking {
  295. val config = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig().apply { alwaysEchoMessages = true }, null)
  296. val client = IrcClientImpl(config)
  297. client.socketFactory = mockSocketFactory
  298. client.resolver = mockResolver
  299. client.serverState.capabilities.enabledCapabilities[Capability.EchoMessages] = ""
  300. client.connect()
  301. client.onEvent(mockEventHandler)
  302. client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")
  303. verify(inverse = true) {
  304. mockEventHandler(ofType<MessageReceived>())
  305. }
  306. }
  307. @Test
  308. fun `does not echo message event when behaviour is unset`() = runBlocking {
  309. val config = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig().apply { alwaysEchoMessages = false }, null)
  310. val client = IrcClientImpl(config)
  311. client.socketFactory = mockSocketFactory
  312. client.resolver = mockResolver
  313. client.connect()
  314. client.onEvent(mockEventHandler)
  315. client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")
  316. verify(inverse = true) {
  317. mockEventHandler(ofType<MessageReceived>())
  318. }
  319. }
  320. @Test
  321. fun `sends structured text to socket with tags`() = runBlocking {
  322. val client = IrcClientImpl(normalConfig)
  323. client.socketFactory = mockSocketFactory
  324. client.resolver = mockResolver
  325. client.connect()
  326. client.send(tagMap(MessageTag.AccountName to "acidB"), "testing", "123", "456")
  327. assertLineReceived("@account=acidB testing 123 456")
  328. }
  329. @Test
  330. fun `asynchronously sends text to socket without label if cap is missing`() = runBlocking {
  331. val client = IrcClientImpl(normalConfig)
  332. client.socketFactory = mockSocketFactory
  333. client.resolver = mockResolver
  334. client.connect()
  335. client.sendAsync(tagMap(), "testing", arrayOf("123")) { false }
  336. assertLineReceived("testing 123")
  337. }
  338. @Test
  339. fun `asynchronously sends text to socket with added tags and label`() = runBlocking {
  340. generateLabel = { "abc123" }
  341. val client = IrcClientImpl(normalConfig)
  342. client.socketFactory = mockSocketFactory
  343. client.resolver = mockResolver
  344. client.serverState.capabilities.enabledCapabilities[Capability.LabeledResponse] = ""
  345. client.connect()
  346. client.sendAsync(tagMap(), "testing", arrayOf("123")) { false }
  347. assertLineReceived("@draft/label=abc123 testing 123")
  348. }
  349. @Test
  350. fun `asynchronously sends tagged text to socket with label`() = runBlocking {
  351. generateLabel = { "abc123" }
  352. val client = IrcClientImpl(normalConfig)
  353. client.socketFactory = mockSocketFactory
  354. client.resolver = mockResolver
  355. client.serverState.capabilities.enabledCapabilities[Capability.LabeledResponse] = ""
  356. client.connect()
  357. client.sendAsync(tagMap(MessageTag.AccountName to "x"), "testing", arrayOf("123")) { false }
  358. assertLineReceived("@account=x;draft/label=abc123 testing 123")
  359. }
  360. @Test
  361. fun `disconnects the socket`() = runBlocking {
  362. val client = IrcClientImpl(normalConfig)
  363. client.socketFactory = mockSocketFactory
  364. client.resolver = mockResolver
  365. client.connect()
  366. launch {
  367. delay(50)
  368. readLineChannel.close()
  369. sendLineChannel.close()
  370. }
  371. client.disconnect()
  372. verify(timeout = 500) {
  373. mockSocket.disconnect()
  374. }
  375. }
  376. @Test
  377. @ObsoleteCoroutinesApi
  378. fun `sends messages in order`() = runBlocking {
  379. val client = IrcClientImpl(normalConfig)
  380. client.socketFactory = mockSocketFactory
  381. client.resolver = mockResolver
  382. client.connect()
  383. (0..100).forEach { client.send("TEST", "$it") }
  384. assertEquals(100, withTimeoutOrNull(500) {
  385. var next = 0
  386. for (line in sendLineChannel.map { String(it) }.filter { it.startsWith("TEST ") }) {
  387. assertEquals("TEST $next", line)
  388. if (++next == 100) {
  389. break
  390. }
  391. }
  392. next
  393. })
  394. }
  395. @Test
  396. @Deprecated("Tests deprecated method")
  397. @RemoveIn("3.0.0")
  398. fun `defaults local nickname to profile`() {
  399. val client = IrcClientImpl(normalConfig)
  400. assertEquals(NICK, client.serverState.localNickname)
  401. }
  402. @Test
  403. fun `defaults local user to nickname in profile`() {
  404. val client = IrcClientImpl(normalConfig)
  405. assertEquals(User(NICK), client.localUser)
  406. }
  407. @Test
  408. fun `defaults server name to host name`() {
  409. val client = IrcClientImpl(normalConfig)
  410. assertEquals(HOST, client.serverState.serverName)
  411. }
  412. @Test
  413. fun `exposes behaviour config`() {
  414. val client = IrcClientImpl(IrcClientConfig(
  415. ServerConfig().apply { host = HOST },
  416. profileConfig,
  417. BehaviourConfig().apply { requestModesOnJoin = true },
  418. null))
  419. assertTrue(client.behaviour.requestModesOnJoin)
  420. }
  421. @Test
  422. fun `reset clears all state`() {
  423. with(IrcClientImpl(normalConfig)) {
  424. userState += User("zeroCool")
  425. channelState += ChannelState("#thegibson") { CaseMapping.Rfc }
  426. serverState.serverName = "root.$HOST"
  427. localUser.awayMessage = "Hacking the planet"
  428. reset()
  429. assertEquals(1, userState.count())
  430. assertEquals(0, channelState.count())
  431. assertEquals(HOST, serverState.serverName)
  432. assertEquals(User("AcidBurn"), localUser)
  433. }
  434. }
  435. @Test
  436. fun `sends connect error when host is unresolvable`() = runBlocking {
  437. every { mockSocket.connect() } throws UnresolvedAddressException()
  438. with(IrcClientImpl(normalConfig)) {
  439. socketFactory = mockSocketFactory
  440. resolver = mockResolver
  441. withTimeout(500) {
  442. launch {
  443. delay(50)
  444. connect()
  445. }
  446. val event = waitForEvent<ServerConnectionError>()
  447. assertEquals(ConnectionError.UnresolvableAddress, event.error)
  448. }
  449. }
  450. }
  451. @Test
  452. fun `sends connect error when tls certificate is bad`() = runBlocking {
  453. every { mockSocket.connect() } throws CertificateException("Boooo")
  454. with(IrcClientImpl(normalConfig)) {
  455. socketFactory = mockSocketFactory
  456. resolver = mockResolver
  457. withTimeout(500) {
  458. launch {
  459. delay(50)
  460. connect()
  461. }
  462. val event = waitForEvent<ServerConnectionError>()
  463. assertEquals(ConnectionError.BadTlsCertificate, event.error)
  464. assertEquals("Boooo", event.details)
  465. }
  466. }
  467. }
  468. @Test
  469. fun `identifies channels that have a prefix in the chantypes feature`() {
  470. with(IrcClientImpl(normalConfig)) {
  471. serverState.features[ServerFeature.ChannelTypes] = "&~"
  472. assertTrue(isChannel("&dumpsterdiving"))
  473. assertTrue(isChannel("~hacktheplanet"))
  474. assertFalse(isChannel("#root"))
  475. assertFalse(isChannel("acidBurn"))
  476. assertFalse(isChannel(""))
  477. assertFalse(isChannel("acidBurn#~"))
  478. }
  479. }
  480. private suspend inline fun <reified T : IrcEvent> IrcClient.waitForEvent(): T {
  481. val mutex = Mutex(true)
  482. val value = AtomicReference<T>()
  483. onEvent {
  484. if (it is T) {
  485. value.set(it)
  486. mutex.unlock()
  487. }
  488. }
  489. mutex.lock()
  490. return value.get()
  491. }
  492. private suspend fun assertLineReceived(expected: String) {
  493. assertEquals(true, withTimeoutOrNull(500) {
  494. for (line in sendLineChannel.map { String(it) }) {
  495. println(line)
  496. if (line == expected) {
  497. return@withTimeoutOrNull true
  498. }
  499. }
  500. false
  501. }) { "Expected to receive $expected" }
  502. }
  503. }