You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

IrcClientImplTest.kt 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. package com.dmdirc.ktirc
  2. import com.dmdirc.ktirc.events.*
  3. import com.dmdirc.ktirc.io.CaseMapping
  4. import com.dmdirc.ktirc.io.LineBufferedSocket
  5. import com.dmdirc.ktirc.messages.tagMap
  6. import com.dmdirc.ktirc.model.*
  7. import com.dmdirc.ktirc.util.currentTimeProvider
  8. import com.dmdirc.ktirc.util.generateLabel
  9. import io.ktor.util.KtorExperimentalAPI
  10. import io.mockk.*
  11. import kotlinx.coroutines.*
  12. import kotlinx.coroutines.channels.Channel
  13. import kotlinx.coroutines.channels.filter
  14. import kotlinx.coroutines.channels.map
  15. import kotlinx.coroutines.sync.Mutex
  16. import org.junit.jupiter.api.Assertions.*
  17. import org.junit.jupiter.api.BeforeEach
  18. import org.junit.jupiter.api.Test
  19. import org.junit.jupiter.api.assertThrows
  20. import java.nio.channels.UnresolvedAddressException
  21. import java.security.cert.CertificateException
  22. import java.util.concurrent.atomic.AtomicReference
  /**
   * Unit tests for [IrcClientImpl].
   *
   * The client's socket factory is replaced with a MockK factory returning [mockSocket], whose
   * receive/send channels are [readLineChannel] and [sendLineChannel]. Tests push server lines into
   * [readLineChannel] and inspect what the client writes via [sendLineChannel].
   */
@KtorExperimentalAPI
@ExperimentalCoroutinesApi
internal class IrcClientImplTest {

    companion object {
        private const val HOST = "thegibson.com"
        private const val PORT = 12345
        private const val NICK = "AcidBurn"
        private const val REAL_NAME = "Kate Libby"
        private const val USER_NAME = "acidb"
        private const val PASSWORD = "HackThePlanet"
    }

    // Unbounded so that test code (and the client) never suspend when sending into them.
    private val readLineChannel = Channel<ByteArray>(Channel.UNLIMITED)
    private val sendLineChannel = Channel<ByteArray>(Channel.UNLIMITED)

    private val mockSocket = mockk<LineBufferedSocket> {
        every { receiveChannel } returns readLineChannel
        every { sendChannel } returns sendLineChannel
    }

    private val mockSocketFactory = mockk<(CoroutineScope, String, Int, Boolean) -> LineBufferedSocket> {
        every { this@mockk.invoke(any(), eq(HOST), eq(PORT), any()) } returns mockSocket
    }

    private val mockEventHandler = mockk<(IrcEvent) -> Unit> {
        every { this@mockk.invoke(any()) } just Runs
    }

    private val profileConfig = ProfileConfig().apply {
        nickname = NICK
        realName = REAL_NAME
        username = USER_NAME
    }

    private val serverConfig = ServerConfig().apply {
        host = HOST
        port = PORT
    }

    private val normalConfig = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig(), null)

    @BeforeEach
    fun setUp() {
        // Pin the clock so event metadata times can be asserted against TestConstants.time.
        currentTimeProvider = { TestConstants.time }
    }

    @Test
    fun `uses socket factory to create a new socket on connect`() {
        val client = IrcClientImpl(normalConfig)
        client.socketFactory = mockSocketFactory
        client.connect()

        verify(timeout = 500) { mockSocketFactory(client, HOST, PORT, false) }
    }

    @Test
    fun `uses socket factory to create a new tls on connect`() {
        val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
            host = HOST
            port = PORT
            useTls = true
        }, profileConfig, BehaviourConfig(), null))
        client.socketFactory = mockSocketFactory
        client.connect()

        verify(timeout = 500) { mockSocketFactory(client, HOST, PORT, true) }
    }

    @Test
    fun `throws if socket already exists`() {
        val client = IrcClientImpl(normalConfig)
        client.socketFactory = mockSocketFactory
        client.connect()

        assertThrows<IllegalStateException> {
            client.connect()
        }
    }

    @Test
    fun `emits connection events with local time`() = runBlocking {
        // The clock is already pinned by setUp().
        val connectingSlot = slot<ServerConnecting>()
        val connectedSlot = slot<ServerConnected>()
        every { mockEventHandler.invoke(capture(connectingSlot)) } just Runs
        every { mockEventHandler.invoke(capture(connectedSlot)) } just Runs

        val client = IrcClientImpl(normalConfig)
        client.socketFactory = mockSocketFactory
        client.onEvent(mockEventHandler)
        client.connect()

        verify(timeout = 500) {
            mockEventHandler(ofType<ServerConnecting>())
            mockEventHandler(ofType<ServerConnected>())
        }
        assertEquals(TestConstants.time, connectingSlot.captured.metadata.time)
        assertEquals(TestConstants.time, connectedSlot.captured.metadata.time)
    }

    @Test
    fun `sends basic connection strings`() = runBlocking {
        val client = IrcClientImpl(normalConfig)
        client.socketFactory = mockSocketFactory
        client.connect()

        assertEquals("CAP LS 302", String(sendLineChannel.receive()))
        assertEquals("NICK $NICK", String(sendLineChannel.receive()))
        assertEquals("USER $USER_NAME 0 * :$REAL_NAME", String(sendLineChannel.receive()))
    }

    @Test
    fun `sends password first, when present`() = runBlocking {
        val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
            host = HOST
            port = PORT
            password = PASSWORD
        }, profileConfig, BehaviourConfig(), null))
        client.socketFactory = mockSocketFactory
        client.connect()

        assertEquals("CAP LS 302", String(sendLineChannel.receive()))
        assertEquals("PASS $PASSWORD", String(sendLineChannel.receive()))
    }

    @Test
    fun `sends events to provided event handler`() = runBlocking {
        val client = IrcClientImpl(normalConfig)
        client.socketFactory = mockSocketFactory
        client.onEvent(mockEventHandler)

        // readLineChannel is unbounded, so this send completes immediately; no need to race a
        // GlobalScope coroutine against connect().
        readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
        client.connect()

        verify(timeout = 500) {
            mockEventHandler(ofType<ServerWelcome>())
        }
    }

    @Test
    fun `gets case mapping from server features`() {
        val client = IrcClientImpl(normalConfig)
        client.serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.RfcStrict

        assertEquals(CaseMapping.RfcStrict, client.caseMapping)
    }

    @Test
    fun `indicates if user is local user or not`() {
        val client = IrcClientImpl(normalConfig)
        client.serverState.localNickname = "[acidBurn]"

        assertTrue(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
        assertFalse(client.isLocalUser(User("acid-Burn", "libby", "root.localhost")))
    }

    @Test
    fun `indicates if nickname is local user or not`() {
        val client = IrcClientImpl(normalConfig)
        client.serverState.localNickname = "[acidBurn]"

        assertTrue(client.isLocalUser("{acidBurn}"))
        assertFalse(client.isLocalUser("acid-Burn"))
    }

    @Test
    fun `uses current case mapping to check local user`() {
        val client = IrcClientImpl(normalConfig)
        client.serverState.localNickname = "[acidBurn]"
        client.serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.Ascii

        assertFalse(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
    }

    @Test
    @Suppress("DEPRECATION") // Kotlin ignores Java's @SuppressWarnings; the raw-line send() is exercised deliberately.
    fun `sends text to socket`() = runBlocking {
        val client = IrcClientImpl(normalConfig)
        client.socketFactory = mockSocketFactory
        client.connect()

        client.send("testing 123")
        assertLineReceived("testing 123")
    }

    @Test
    fun `sends structured text to socket`() = runBlocking {
        val client = IrcClientImpl(normalConfig)
        client.socketFactory = mockSocketFactory
        client.connect()

        client.send("testing", "123", "456")
        assertLineReceived("testing 123 456")
    }

    @Test
    fun `echoes message event when behaviour is set and cap is unsupported`() = runBlocking {
        val config = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig().apply { alwaysEchoMessages = true }, null)
        val client = IrcClientImpl(config)
        client.socketFactory = mockSocketFactory

        val slot = slot<MessageReceived>()
        val echoHandler = mockk<(IrcEvent) -> Unit>(relaxed = true)
        every { echoHandler(capture(slot)) } just Runs
        client.onEvent(echoHandler)
        client.connect()

        client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")

        assertTrue(slot.isCaptured)
        val event = slot.captured
        assertEquals("#thegibson", event.target)
        assertEquals("Mess with the best, die like the rest", event.message)
        assertEquals(NICK, event.user.nickname)
        assertEquals(TestConstants.time, event.metadata.time)
    }

    @Test
    fun `does not echo message event when behaviour is set and cap is supported`() = runBlocking {
        val config = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig().apply { alwaysEchoMessages = true }, null)
        val client = IrcClientImpl(config)
        client.socketFactory = mockSocketFactory
        client.serverState.capabilities.enabledCapabilities[Capability.EchoMessages] = ""
        client.connect()
        client.onEvent(mockEventHandler)

        client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")

        verify(inverse = true) {
            mockEventHandler(ofType<MessageReceived>())
        }
    }

    @Test
    fun `does not echo message event when behaviour is unset`() = runBlocking {
        val config = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig().apply { alwaysEchoMessages = false }, null)
        val client = IrcClientImpl(config)
        client.socketFactory = mockSocketFactory
        client.connect()
        client.onEvent(mockEventHandler)

        client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")

        verify(inverse = true) {
            mockEventHandler(ofType<MessageReceived>())
        }
    }

    @Test
    fun `sends structured text to socket with tags`() = runBlocking {
        val client = IrcClientImpl(normalConfig)
        client.socketFactory = mockSocketFactory
        client.connect()

        client.send(tagMap(MessageTag.AccountName to "acidB"), "testing", "123", "456")
        assertLineReceived("@account=acidB testing 123 456")
    }

    @Test
    fun `asynchronously sends text to socket without label if cap is missing`() = runBlocking {
        val client = IrcClientImpl(normalConfig)
        client.socketFactory = mockSocketFactory
        client.connect()

        client.sendAsync(tagMap(), "testing", "123")
        assertLineReceived("testing 123")
    }

    @Test
    fun `asynchronously sends text to socket with added tags and label`() = runBlocking {
        generateLabel = { "abc123" }
        val client = IrcClientImpl(normalConfig)
        client.socketFactory = mockSocketFactory
        client.serverState.capabilities.enabledCapabilities[Capability.LabeledResponse] = ""
        client.connect()

        client.sendAsync(tagMap(), "testing", "123")
        assertLineReceived("@draft/label=abc123 testing 123")
    }

    @Test
    fun `asynchronously sends tagged text to socket with label`() = runBlocking {
        generateLabel = { "abc123" }
        val client = IrcClientImpl(normalConfig)
        client.socketFactory = mockSocketFactory
        client.serverState.capabilities.enabledCapabilities[Capability.LabeledResponse] = ""
        client.connect()

        client.sendAsync(tagMap(MessageTag.AccountName to "x"), "testing", "123")
        assertLineReceived("@account=x;draft/label=abc123 testing 123")
    }

    @Test
    fun `disconnects the socket`() = runBlocking {
        val client = IrcClientImpl(normalConfig)
        client.socketFactory = mockSocketFactory
        client.connect()

        client.disconnect()

        verify(timeout = 500) {
            mockSocket.disconnect()
        }
    }

    @Test
    fun `sends messages in order`() = runBlocking {
        val client = IrcClientImpl(normalConfig)
        client.socketFactory = mockSocketFactory
        client.connect()

        val messageCount = 100
        repeat(messageCount) { client.send("TEST", "$it") }

        assertEquals(messageCount, withTimeoutOrNull(500) {
            var next = 0
            // Iterate the raw channel: wrapping it in map/filter spawns a producer coroutine that
            // keeps consuming lines (and leaks) after this loop exits.
            for (bytes in sendLineChannel) {
                val line = String(bytes)
                if (!line.startsWith("TEST ")) continue // skip registration lines (CAP/NICK/USER)
                assertEquals("TEST $next", line)
                if (++next == messageCount) break
            }
            next
        })
    }

    @Test
    fun `defaults local nickname to profile`() {
        val client = IrcClientImpl(normalConfig)
        assertEquals(NICK, client.serverState.localNickname)
    }

    @Test
    fun `defaults server name to host name`() {
        val client = IrcClientImpl(normalConfig)
        assertEquals(HOST, client.serverState.serverName)
    }

    @Test
    fun `exposes behaviour config`() {
        val client = IrcClientImpl(IrcClientConfig(
                ServerConfig().apply { host = HOST },
                profileConfig,
                BehaviourConfig().apply { requestModesOnJoin = true },
                null))

        assertTrue(client.behaviour.requestModesOnJoin)
    }

    @Test
    fun `reset clears all state`() {
        with(IrcClientImpl(normalConfig)) {
            userState += User("acidBurn")
            channelState += ChannelState("#thegibson") { CaseMapping.Rfc }
            serverState.serverName = "root.$HOST"

            reset()

            assertEquals(0, userState.count())
            assertEquals(0, channelState.count())
            assertEquals(HOST, serverState.serverName)
        }
    }

    @Test
    fun `sends connect error when host is unresolvable`() = runBlocking {
        every { mockSocket.connect() } throws UnresolvedAddressException()
        with(IrcClientImpl(normalConfig)) {
            socketFactory = mockSocketFactory
            withTimeout(500) {
                launch {
                    // Give waitForEvent a chance to register its handler before connecting.
                    delay(50)
                    connect()
                }
                val event = waitForEvent<ServerConnectionError>()
                assertEquals(ConnectionError.UnresolvableAddress, event.error)
            }
        }
    }

    @Test
    fun `sends connect error when tls certificate is bad`() = runBlocking {
        every { mockSocket.connect() } throws CertificateException("Boooo")
        with(IrcClientImpl(normalConfig)) {
            socketFactory = mockSocketFactory
            withTimeout(500) {
                launch {
                    // Give waitForEvent a chance to register its handler before connecting.
                    delay(50)
                    connect()
                }
                val event = waitForEvent<ServerConnectionError>()
                assertEquals(ConnectionError.BadTlsCertificate, event.error)
                assertEquals("Boooo", event.details)
            }
        }
    }

    @Test
    fun `identifies channels that have a prefix in the chantypes feature`() {
        with(IrcClientImpl(normalConfig)) {
            serverState.features[ServerFeature.ChannelTypes] = "&~"

            assertTrue(isChannel("&dumpsterdiving"))
            assertTrue(isChannel("~hacktheplanet"))
            assertFalse(isChannel("#root"))
            assertFalse(isChannel("acidBurn"))
            assertFalse(isChannel(""))
            assertFalse(isChannel("acidBurn#~"))
        }
    }

    /**
     * Registers an event handler on this client and suspends until the first event of type [T]
     * is emitted, returning it.
     *
     * A [CompletableDeferred] is used because `complete` is a no-op after the first call. Later
     * matching events are therefore ignored instead of throwing inside the client's handler.
     * The registered handler is never removed.
     */
    private suspend inline fun <reified T : IrcEvent> IrcClient.waitForEvent(): T {
        val firstMatch = CompletableDeferred<T>()
        onEvent { event ->
            if (event is T) {
                firstMatch.complete(event)
            }
        }
        return firstMatch.await()
    }

    /**
     * Consumes lines from [sendLineChannel] until one equals [expected], failing the test if none
     * arrives within 500ms.
     *
     * Lines read before the match are consumed and discarded. They are listed in the failure
     * message to aid diagnosis.
     */
    private suspend fun assertLineReceived(expected: String) {
        val seen = mutableListOf<String>()
        val found = withTimeoutOrNull(500) {
            for (bytes in sendLineChannel) {
                val line = String(bytes)
                seen += line
                if (line == expected) {
                    return@withTimeoutOrNull true
                }
            }
            false
        }
        assertEquals(true, found) { "Expected to receive $expected; received $seen" }
    }
}