You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

IrcClientImplTest.kt 20KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596
  1. package com.dmdirc.ktirc
  2. import com.dmdirc.ktirc.events.*
  3. import com.dmdirc.ktirc.io.CaseMapping
  4. import com.dmdirc.ktirc.io.LineBufferedSocket
  5. import com.dmdirc.ktirc.messages.tagMap
  6. import com.dmdirc.ktirc.model.*
  7. import com.dmdirc.ktirc.util.currentTimeProvider
  8. import com.dmdirc.ktirc.util.generateLabel
  9. import io.mockk.*
  10. import kotlinx.coroutines.*
  11. import kotlinx.coroutines.channels.Channel
  12. import kotlinx.coroutines.channels.filter
  13. import kotlinx.coroutines.channels.map
  14. import kotlinx.coroutines.sync.Mutex
  15. import org.junit.jupiter.api.Assertions.*
  16. import org.junit.jupiter.api.BeforeEach
  17. import org.junit.jupiter.api.Test
  18. import org.junit.jupiter.api.assertThrows
  19. import java.net.UnknownHostException
  20. import java.nio.channels.UnresolvedAddressException
  21. import java.security.cert.CertificateException
  22. import java.util.concurrent.atomic.AtomicReference
  23. @ExperimentalCoroutinesApi
  24. internal class IrcClientImplTest {
  /** Constant fixture values shared by the tests in this class. */
  companion object {
    // Server addressing
    private const val HOST = "thegibson.com"
    private const val HOST2 = "irc.thegibson.com"
    private const val IP = "127.0.13.37"
    private const val PORT = 12345
    private const val PASSWORD = "HackThePlanet"

    // Local user profile
    private const val NICK = "AcidBurn"
    private const val USER_NAME = "acidb"
    private const val REAL_NAME = "Kate Libby"
  }
  // Unlimited-capacity channels handed out by the mocked socket for inbound and outbound lines.
  private val readLineChannel = Channel<ByteArray>(Channel.UNLIMITED)
  private val sendLineChannel = Channel<ByteArray>(Channel.UNLIMITED)

  // Socket mock: only its two channel properties are stubbed here.
  private val mockSocket = mockk<LineBufferedSocket>().also { socket ->
    every { socket.receiveChannel } returns readLineChannel
    every { socket.sendChannel } returns sendLineChannel
  }

  // Socket factory mock: returns [mockSocket] for HOST at IP, or for HOST2 at any address, on PORT.
  private val mockSocketFactory =
    mockk<(CoroutineScope, String, String, Int, Boolean) -> LineBufferedSocket>().also { factory ->
      every { factory(any(), eq(HOST), eq(IP), eq(PORT), any()) } returns mockSocket
      every { factory(any(), eq(HOST2), any(), eq(PORT), any()) } returns mockSocket
    }

  // Resolver mock: HOST resolves to the single non-IPv6 address IP.
  private val mockResolver = mockk<(String) -> Collection<ResolveResult>>().also { resolve ->
    every { resolve(HOST) } returns listOf(ResolveResult(IP, false))
  }

  // Event handler mock that accepts any event.
  private val mockEventHandler = mockk<(IrcEvent) -> Unit>().also { handler ->
    every { handler(any()) } just Runs
  }

  private val profileConfig = ProfileConfig().also {
    it.nickname = NICK
    it.realName = REAL_NAME
    it.username = USER_NAME
  }

  private val serverConfig = ServerConfig().also {
    it.host = HOST
    it.port = PORT
    it.useTls = false
  }

  // Default client configuration: plain-text connection to HOST:PORT with default behaviour.
  private val normalConfig = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig(), null)
  /** Points [currentTimeProvider] at the fixed [TestConstants.time] before each test. */
  @BeforeEach
  fun setUp() {
    currentTimeProvider = { TestConstants.time }
  }
  /** Connecting with the default config must request a non-TLS socket for the resolved address. */
  @Test
  fun `uses socket factory to create a new socket on connect`() {
    val client = IrcClientImpl(normalConfig).apply {
      socketFactory = mockSocketFactory
      resolver = mockResolver
    }

    client.connect()

    verify(timeout = 500) { mockSocketFactory(client, HOST, IP, PORT, false) }
  }
  /** With `useTls` enabled, the socket factory must be invoked with the TLS flag set. */
  @Test
  fun `uses socket factory to create a new tls on connect`() {
    val tlsServer = ServerConfig().apply {
      host = HOST
      port = PORT
      useTls = true
    }
    val client = IrcClientImpl(IrcClientConfig(tlsServer, profileConfig, BehaviourConfig(), null)).apply {
      socketFactory = mockSocketFactory
      resolver = mockResolver
    }

    client.connect()

    verify(timeout = 500) { mockSocketFactory(client, HOST, IP, PORT, true) }
  }
  /** With `preferIPv6` set and mixed results, the IPv6 address is the one passed to the factory. */
  @Test
  fun `prefers ipv6 addresses if behaviour is enabled`() {
    val server = ServerConfig().apply {
      host = HOST2
      port = PORT
    }
    val behaviour = BehaviourConfig().apply { preferIPv6 = true }
    val client = IrcClientImpl(IrcClientConfig(server, profileConfig, behaviour, null))

    val addresses = listOf(
      ResolveResult(IP, false),
      ResolveResult("::13:37", true),
      ResolveResult("0.0.0.0", false)
    )
    every { mockResolver(HOST2) } returns addresses

    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver
    client.connect()

    verify(timeout = 500) { mockSocketFactory(client, HOST2, "::13:37", PORT, true) }
  }
  /** With `preferIPv6` set but only an IPv4 result, that IPv4 address is used. */
  @Test
  fun `falls back to ipv4 if no ipv6 addresses are available`() {
    val server = ServerConfig().apply {
      host = HOST2
      port = PORT
    }
    val behaviour = BehaviourConfig().apply { preferIPv6 = true }
    val client = IrcClientImpl(IrcClientConfig(server, profileConfig, behaviour, null))

    every { mockResolver(HOST2) } returns listOf(ResolveResult("0.0.0.0", false))

    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver
    client.connect()

    verify(timeout = 500) { mockSocketFactory(client, HOST2, "0.0.0.0", PORT, true) }
  }
  /** With `preferIPv6` unset and mixed results, the IPv4 address is the one passed to the factory. */
  @Test
  fun `prefers ipv4 addresses if ipv6 behaviour is disabled`() {
    val server = ServerConfig().apply {
      host = HOST2
      port = PORT
    }
    val behaviour = BehaviourConfig().apply { preferIPv6 = false }
    val client = IrcClientImpl(IrcClientConfig(server, profileConfig, behaviour, null))

    val addresses = listOf(
      ResolveResult("::13:37", true),
      ResolveResult("::313:37", true),
      ResolveResult("0.0.0.0", false)
    )
    every { mockResolver(HOST2) } returns addresses

    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver
    client.connect()

    verify(timeout = 500) { mockSocketFactory(client, HOST2, "0.0.0.0", PORT, true) }
  }
  /** With `preferIPv6` unset but only an IPv6 result, that IPv6 address is used. */
  @Test
  fun `falls back to ipv6 if no ipv4 addresses available`() {
    val server = ServerConfig().apply {
      host = HOST2
      port = PORT
    }
    val behaviour = BehaviourConfig().apply { preferIPv6 = false }
    val client = IrcClientImpl(IrcClientConfig(server, profileConfig, behaviour, null))

    every { mockResolver(HOST2) } returns listOf(ResolveResult("::13:37", true))

    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver
    client.connect()

    verify(timeout = 500) { mockSocketFactory(client, HOST2, "::13:37", PORT, true) }
  }
  /** A resolver that throws [UnknownHostException] must surface as an UnresolvableAddress error event. */
  @Test
  fun `raises error if dns fails`() {
    val server = ServerConfig().apply { host = HOST2 }
    val behaviour = BehaviourConfig().apply { preferIPv6 = true }
    val client = IrcClientImpl(IrcClientConfig(server, profileConfig, behaviour, null))

    every { mockResolver(HOST2) } throws UnknownHostException("oops")

    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver
    client.onEvent(mockEventHandler)
    client.connect()

    verify(timeout = 500) {
      mockEventHandler(match { event ->
        (event as? ServerConnectionError)?.error == ConnectionError.UnresolvableAddress
      })
    }
  }
  /** A second `connect()` on the same client must fail with [IllegalStateException]. */
  @Test
  fun `throws if socket already exists`() {
    val client = IrcClientImpl(normalConfig).apply {
      socketFactory = mockSocketFactory
      resolver = mockResolver
    }

    client.connect()

    assertThrows<IllegalStateException> { client.connect() }
  }
  /** Both the connecting and connected events must carry the time supplied by [currentTimeProvider]. */
  @Test
  fun `emits connection events with local time`() = runBlocking {
    currentTimeProvider = { TestConstants.time }

    // One typed slot per event kind; each capture only matches its own event type.
    val connecting = slot<ServerConnecting>()
    val connected = slot<ServerConnected>()
    every { mockEventHandler.invoke(capture(connecting)) } just Runs
    every { mockEventHandler.invoke(capture(connected)) } just Runs

    val client = IrcClientImpl(normalConfig)
    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver
    client.onEvent(mockEventHandler)
    client.connect()

    verify(timeout = 500) {
      mockEventHandler(ofType<ServerConnecting>())
      mockEventHandler(ofType<ServerConnected>())
    }

    assertEquals(TestConstants.time, connecting.captured.metadata.time)
    assertEquals(TestConstants.time, connected.captured.metadata.time)
  }
  /** After connecting, the first three outbound lines are CAP LS, NICK and USER, in that order. */
  @Test
  fun `sends basic connection strings`() = runBlocking {
    val client = IrcClientImpl(normalConfig).apply {
      socketFactory = mockSocketFactory
      resolver = mockResolver
    }
    client.connect()

    val expectedLines = listOf(
      "CAP LS 302",
      "NICK $NICK",
      "USER $USER_NAME 0 * :$REAL_NAME"
    )
    for (expected in expectedLines) {
      assertEquals(expected, String(sendLineChannel.receive()))
    }
  }
  /** When a server password is configured, PASS follows immediately after CAP LS. */
  @Test
  fun `sends password first, when present`() = runBlocking {
    val passwordServer = ServerConfig().apply {
      host = HOST
      port = PORT
      password = PASSWORD
    }
    val client = IrcClientImpl(IrcClientConfig(passwordServer, profileConfig, BehaviourConfig(), null)).apply {
      socketFactory = mockSocketFactory
      resolver = mockResolver
    }
    client.connect()

    for (expected in listOf("CAP LS 302", "PASS $PASSWORD")) {
      assertEquals(expected, String(sendLineChannel.receive()))
    }
  }
  /**
   * A line read from the socket must be delivered to registered handlers as an event.
   *
   * The welcome line is queued before connecting. `readLineChannel` has unlimited capacity, so
   * `send` completes without suspending and no detached `GlobalScope` coroutine is needed.
   */
  @Test
  fun `sends events to provided event handler`() {
    val client = IrcClientImpl(normalConfig)
    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver
    client.onEvent(mockEventHandler)

    runBlocking {
      readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
    }

    client.connect()

    verify(timeout = 500) {
      mockEventHandler(ofType<ServerWelcome>())
    }
  }
  /** `caseMapping` must reflect the ServerCaseMapping entry stored in the server features. */
  @Test
  fun `gets case mapping from server features`() {
    with(IrcClientImpl(normalConfig)) {
      serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.RfcStrict
      assertEquals(CaseMapping.RfcStrict, caseMapping)
    }
  }
  /** `isLocalUser(User)` treats "[acidBurn]" and "{acidBurn}" as the same nickname, but not "acid-Burn". */
  @Test
  fun `indicates if user is local user or not`() {
    with(IrcClientImpl(normalConfig)) {
      serverState.localNickname = "[acidBurn]"

      val sameUser = User("{acidBurn}", "libby", "root.localhost")
      val otherUser = User("acid-Burn", "libby", "root.localhost")

      assertTrue(isLocalUser(sameUser))
      assertFalse(isLocalUser(otherUser))
    }
  }
  /** `isLocalUser(String)` treats "[acidBurn]" and "{acidBurn}" as the same nickname, but not "acid-Burn". */
  @Test
  fun `indicates if nickname is local user or not`() {
    with(IrcClientImpl(normalConfig)) {
      serverState.localNickname = "[acidBurn]"

      assertTrue(isLocalUser("{acidBurn}"))
      assertFalse(isLocalUser("acid-Burn"))
    }
  }
  /** Under the Ascii case mapping, "{acidBurn}" no longer matches a local nickname of "[acidBurn]". */
  @Test
  fun `uses current case mapping to check local user`() {
    with(IrcClientImpl(normalConfig)) {
      serverState.localNickname = "[acidBurn]"
      serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.Ascii

      val candidate = User("{acidBurn}", "libby", "root.localhost")
      assertFalse(isLocalUser(candidate))
    }
  }
  /**
   * A raw line passed to the single-argument `send` must reach the socket's send channel verbatim.
   *
   * Uses Kotlin's `@Suppress("DEPRECATION")`: the Java `@SuppressWarnings` annotation is not
   * honoured by the Kotlin compiler, so it never silenced the warning it was aimed at
   * (presumably the single-string `send` overload is deprecated - confirm against IrcClient).
   */
  @Test
  @Suppress("DEPRECATION")
  fun `sends text to socket`() = runBlocking {
    val client = IrcClientImpl(normalConfig)
    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver
    client.connect()

    client.send("testing 123")

    assertLineReceived("testing 123")
  }
  /** A command plus arguments is written to the socket as a single space-separated line. */
  @Test
  fun `sends structured text to socket`() = runBlocking {
    val client = IrcClientImpl(normalConfig).apply {
      socketFactory = mockSocketFactory
      resolver = mockResolver
    }
    client.connect()

    client.send("testing", "123", "456")

    assertLineReceived("testing 123 456")
  }
  /** With `alwaysEchoMessages` set and no echo capability, sending a PRIVMSG emits a local MessageReceived. */
  @Test
  fun `echoes message event when behaviour is set and cap is unsupported`() = runBlocking {
    val echoingBehaviour = BehaviourConfig().apply { alwaysEchoMessages = true }
    val client = IrcClientImpl(IrcClientConfig(serverConfig, profileConfig, echoingBehaviour, null))
    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver

    // Relaxed handler that additionally captures any MessageReceived it is given.
    val captured = slot<MessageReceived>()
    val handler = mockk<(IrcEvent) -> Unit>(relaxed = true)
    every { handler(capture(captured)) } just Runs
    client.onEvent(handler)

    client.connect()
    client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")

    assertTrue(captured.isCaptured)
    val echoed = captured.captured
    assertEquals("#thegibson", echoed.target)
    assertEquals("Mess with the best, die like the rest", echoed.message)
    assertEquals(NICK, echoed.user.nickname)
    assertEquals(TestConstants.time, echoed.metadata.time)
  }
  /** With the EchoMessages capability enabled, the client must not synthesise its own MessageReceived. */
  @Test
  fun `does not echo message event when behaviour is set and cap is supported`() = runBlocking {
    val echoingBehaviour = BehaviourConfig().apply { alwaysEchoMessages = true }
    val client = IrcClientImpl(IrcClientConfig(serverConfig, profileConfig, echoingBehaviour, null))
    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver
    client.serverState.capabilities.enabledCapabilities[Capability.EchoMessages] = ""

    client.connect()
    client.onEvent(mockEventHandler)
    client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")

    verify(inverse = true) {
      mockEventHandler(ofType<MessageReceived>())
    }
  }
  /** With `alwaysEchoMessages` disabled, sending a PRIVMSG must not emit a MessageReceived event. */
  @Test
  fun `does not echo message event when behaviour is unset`() = runBlocking {
    val quietBehaviour = BehaviourConfig().apply { alwaysEchoMessages = false }
    val client = IrcClientImpl(IrcClientConfig(serverConfig, profileConfig, quietBehaviour, null))
    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver

    client.connect()
    client.onEvent(mockEventHandler)
    client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")

    verify(inverse = true) {
      mockEventHandler(ofType<MessageReceived>())
    }
  }
  /** Tags supplied to `send` are serialised as an `@key=value` prefix on the outgoing line. */
  @Test
  fun `sends structured text to socket with tags`() = runBlocking {
    val client = IrcClientImpl(normalConfig).apply {
      socketFactory = mockSocketFactory
      resolver = mockResolver
    }
    client.connect()

    val tags = tagMap(MessageTag.AccountName to "acidB")
    client.send(tags, "testing", "123", "456")

    assertLineReceived("@account=acidB testing 123 456")
  }
  /** Without the LabeledResponse capability, `sendAsync` writes the line with no label tag. */
  @Test
  fun `asynchronously sends text to socket without label if cap is missing`() = runBlocking {
    val client = IrcClientImpl(normalConfig).apply {
      socketFactory = mockSocketFactory
      resolver = mockResolver
    }
    client.connect()

    client.sendAsync(tagMap(), "testing", arrayOf("123")) { false }

    assertLineReceived("testing 123")
  }
  /** With LabeledResponse enabled, `sendAsync` prefixes the line with the generated label tag. */
  @Test
  fun `asynchronously sends text to socket with added tags and label`() = runBlocking {
    generateLabel = { "abc123" }

    val client = IrcClientImpl(normalConfig)
    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver
    client.serverState.capabilities.enabledCapabilities[Capability.LabeledResponse] = ""
    client.connect()

    client.sendAsync(tagMap(), "testing", arrayOf("123")) { false }

    assertLineReceived("@draft/label=abc123 testing 123")
  }
  /** With LabeledResponse enabled, the generated label is appended after caller-supplied tags. */
  @Test
  fun `asynchronously sends tagged text to socket with label`() = runBlocking {
    generateLabel = { "abc123" }

    val client = IrcClientImpl(normalConfig)
    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver
    client.serverState.capabilities.enabledCapabilities[Capability.LabeledResponse] = ""
    client.connect()

    val callerTags = tagMap(MessageTag.AccountName to "x")
    client.sendAsync(callerTags, "testing", arrayOf("123")) { false }

    assertLineReceived("@account=x;draft/label=abc123 testing 123")
  }
  /** `disconnect()` on a connected client must call `disconnect()` on the underlying socket. */
  @Test
  fun `disconnects the socket`() = runBlocking {
    val client = IrcClientImpl(normalConfig).apply {
      socketFactory = mockSocketFactory
      resolver = mockResolver
    }

    client.connect()
    client.disconnect()

    verify(timeout = 500) { mockSocket.disconnect() }
  }
  /**
   * Lines handed to `send` must reach the socket in the order they were submitted.
   *
   * The send channel is iterated directly rather than through the obsolete channel `map`/`filter`
   * operators. Those start extra producer coroutines that keep consuming from the channel after
   * this loop breaks, and they require an `@ObsoleteCoroutinesApi` opt-in that is no longer needed.
   */
  @Test
  fun `sends messages in order`() = runBlocking {
    val client = IrcClientImpl(normalConfig)
    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver
    client.connect()

    (0..100).forEach { client.send("TEST", "$it") }

    // Yields null (failing the assertion) if 100 ordered TEST lines do not arrive within the timeout.
    assertEquals(100, withTimeoutOrNull(500) {
      var next = 0
      for (bytes in sendLineChannel) {
        val line = String(bytes)
        // Skip connection preamble and anything else that is not one of our TEST lines.
        if (!line.startsWith("TEST ")) continue
        assertEquals("TEST $next", line)
        if (++next == 100) break
      }
      next
    })
  }
  /** A freshly constructed client reports the profile's nickname as its local nickname. */
  @Test
  fun `defaults local nickname to profile`() {
    val localNickname = IrcClientImpl(normalConfig).serverState.localNickname
    assertEquals(NICK, localNickname)
  }
  /** A freshly constructed client reports the configured host as its server name. */
  @Test
  fun `defaults server name to host name`() {
    val serverName = IrcClientImpl(normalConfig).serverState.serverName
    assertEquals(HOST, serverName)
  }
  /** The behaviour settings supplied in the config are readable through `client.behaviour`. */
  @Test
  fun `exposes behaviour config`() {
    val server = ServerConfig().apply { host = HOST }
    val behaviour = BehaviourConfig().apply { requestModesOnJoin = true }

    val client = IrcClientImpl(IrcClientConfig(server, profileConfig, behaviour, null))

    assertTrue(client.behaviour.requestModesOnJoin)
  }
  /** `reset()` empties user and channel state and restores the server name to the configured host. */
  @Test
  fun `reset clears all state`() {
    val client = IrcClientImpl(normalConfig)

    // Dirty each piece of state that reset() is expected to clear.
    client.userState += User("acidBurn")
    client.channelState += ChannelState("#thegibson") { CaseMapping.Rfc }
    client.serverState.serverName = "root.$HOST"

    client.reset()

    assertEquals(0, client.userState.count())
    assertEquals(0, client.channelState.count())
    assertEquals(HOST, client.serverState.serverName)
  }
  /** An [UnresolvedAddressException] from the socket's `connect()` is reported as UnresolvableAddress. */
  @Test
  fun `sends connect error when host is unresolvable`() = runBlocking {
    every { mockSocket.connect() } throws UnresolvedAddressException()

    val client = IrcClientImpl(normalConfig)
    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver

    withTimeout(500) {
      // Connect slightly later so the event listener below is registered first.
      launch {
        delay(50)
        client.connect()
      }

      val failure = client.waitForEvent<ServerConnectionError>()
      assertEquals(ConnectionError.UnresolvableAddress, failure.error)
    }
  }
  /** A [CertificateException] from the socket's `connect()` is reported as BadTlsCertificate with its message. */
  @Test
  fun `sends connect error when tls certificate is bad`() = runBlocking {
    every { mockSocket.connect() } throws CertificateException("Boooo")

    val client = IrcClientImpl(normalConfig)
    client.socketFactory = mockSocketFactory
    client.resolver = mockResolver

    withTimeout(500) {
      // Connect slightly later so the event listener below is registered first.
      launch {
        delay(50)
        client.connect()
      }

      val failure = client.waitForEvent<ServerConnectionError>()
      assertEquals(ConnectionError.BadTlsCertificate, failure.error)
      assertEquals("Boooo", failure.details)
    }
  }
  /** `isChannel` is true only for names whose first character is listed in the ChannelTypes feature. */
  @Test
  fun `identifies channels that have a prefix in the chantypes feature`() {
    val client = IrcClientImpl(normalConfig)
    client.serverState.features[ServerFeature.ChannelTypes] = "&~"

    val channels = listOf("&dumpsterdiving", "~hacktheplanet")
    val nonChannels = listOf("#root", "acidBurn", "", "acidBurn#~")

    channels.forEach { assertTrue(client.isChannel(it)) }
    nonChannels.forEach { assertFalse(client.isChannel(it)) }
  }
  /**
   * Registers an event handler on this client and suspends until the first event of type [T] arrives.
   *
   * A [CompletableDeferred] carries the result: completing it a second time is a no-op, so further
   * matching events are ignored. The previous pre-locked `Mutex` threw from inside the handler when
   * unlocked twice.
   *
   * @return the first event of type [T] delivered to the handler.
   */
  private suspend inline fun <reified T : IrcEvent> IrcClient.waitForEvent(): T {
    val firstMatch = CompletableDeferred<T>()
    onEvent { event ->
      if (event is T) {
        firstMatch.complete(event)
      }
    }
    return firstMatch.await()
  }
  /**
   * Asserts that [expected] appears on the send channel within 500ms, discarding any lines before it.
   *
   * The channel is iterated directly rather than through the obsolete channel `map` operator, which
   * starts a producer coroutine that keeps consuming lines after this function returns. The debug
   * `println` of every line has also been dropped.
   *
   * @param expected the exact line that must be received.
   */
  private suspend fun assertLineReceived(expected: String) {
    val found = withTimeoutOrNull(500) {
      for (bytes in sendLineChannel) {
        if (String(bytes) == expected) {
          return@withTimeoutOrNull true
        }
      }
      // Only reached if the channel is closed before the expected line shows up.
      false
    }
    // `found` is null on timeout, which also fails the comparison with `true`.
    assertEquals(true, found) { "Expected to receive $expected" }
  }
  487. }