You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

IrcClientImplTest.kt 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. package com.dmdirc.ktirc
  2. import com.dmdirc.ktirc.events.*
  3. import com.dmdirc.ktirc.io.CaseMapping
  4. import com.dmdirc.ktirc.io.LineBufferedSocket
  5. import com.dmdirc.ktirc.messages.tagMap
  6. import com.dmdirc.ktirc.model.*
  7. import com.dmdirc.ktirc.util.currentTimeProvider
  8. import com.dmdirc.ktirc.util.generateLabel
  9. import com.nhaarman.mockitokotlin2.*
  10. import io.ktor.util.KtorExperimentalAPI
  11. import io.mockk.*
  12. import kotlinx.coroutines.*
  13. import kotlinx.coroutines.channels.Channel
  14. import kotlinx.coroutines.channels.filter
  15. import kotlinx.coroutines.channels.map
  16. import kotlinx.coroutines.sync.Mutex
  17. import org.junit.jupiter.api.Assertions.*
  18. import org.junit.jupiter.api.BeforeEach
  19. import org.junit.jupiter.api.Test
  20. import org.junit.jupiter.api.assertThrows
  21. import java.nio.channels.UnresolvedAddressException
  22. import java.security.cert.CertificateException
  23. import java.util.concurrent.atomic.AtomicReference
  24. @KtorExperimentalAPI
  25. @ExperimentalCoroutinesApi
  26. internal class IrcClientImplTest {
  27. companion object {
  28. private const val HOST = "thegibson.com"
  29. private const val PORT = 12345
  30. private const val NICK = "AcidBurn"
  31. private const val REAL_NAME = "Kate Libby"
  32. private const val USER_NAME = "acidb"
  33. private const val PASSWORD = "HackThePlanet"
  34. }
  35. private val readLineChannel = Channel<ByteArray>(Channel.UNLIMITED)
  36. private val sendLineChannel = Channel<ByteArray>(Channel.UNLIMITED)
  37. private val mockSocket = mock<LineBufferedSocket> {
  38. on { receiveChannel } doReturn readLineChannel
  39. on { sendChannel } doReturn sendLineChannel
  40. }
  41. private val mockSocketFactory = mock<(CoroutineScope, String, Int, Boolean) -> LineBufferedSocket> {
  42. on { invoke(any(), eq(HOST), eq(PORT), any()) } doReturn mockSocket
  43. }
  44. private val mockEventHandler = mock<(IrcEvent) -> Unit>()
  45. private val profileConfig = ProfileConfig().apply {
  46. nickname = NICK
  47. realName = REAL_NAME
  48. username = USER_NAME
  49. }
  50. private val serverConfig = ServerConfig().apply {
  51. host = HOST
  52. port = PORT
  53. }
  54. private val normalConfig = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig(), null)
  55. @BeforeEach
  56. fun setUp() {
  57. currentTimeProvider = { TestConstants.time }
  58. }
  59. @Test
  60. fun `uses socket factory to create a new socket on connect`() {
  61. val client = IrcClientImpl(normalConfig)
  62. client.socketFactory = mockSocketFactory
  63. client.connect()
  64. verify(mockSocketFactory, timeout(500)).invoke(client, HOST, PORT, false)
  65. }
  66. @Test
  67. fun `uses socket factory to create a new tls on connect`() {
  68. val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
  69. host = HOST
  70. port = PORT
  71. useTls = true
  72. }, profileConfig, BehaviourConfig(), null))
  73. client.socketFactory = mockSocketFactory
  74. client.connect()
  75. verify(mockSocketFactory, timeout(500)).invoke(client, HOST, PORT, true)
  76. }
  77. @Test
  78. fun `throws if socket already exists`() {
  79. val client = IrcClientImpl(normalConfig)
  80. client.socketFactory = mockSocketFactory
  81. client.connect()
  82. assertThrows<IllegalStateException> {
  83. client.connect()
  84. }
  85. }
  86. @Test
  87. fun `emits connection events with local time`() = runBlocking {
  88. currentTimeProvider = { TestConstants.time }
  89. val client = IrcClientImpl(normalConfig)
  90. client.socketFactory = mockSocketFactory
  91. client.onEvent(mockEventHandler)
  92. client.connect()
  93. val captor = argumentCaptor<IrcEvent>()
  94. verify(mockEventHandler, timeout(500).atLeast(2)).invoke(captor.capture())
  95. assertTrue(captor.firstValue is ServerConnecting)
  96. assertEquals(TestConstants.time, captor.firstValue.metadata.time)
  97. assertTrue(captor.secondValue is ServerConnected)
  98. assertEquals(TestConstants.time, captor.secondValue.metadata.time)
  99. }
  100. @Test
  101. fun `sends basic connection strings`() = runBlocking {
  102. val client = IrcClientImpl(normalConfig)
  103. client.socketFactory = mockSocketFactory
  104. client.connect()
  105. assertEquals("CAP LS 302", String(sendLineChannel.receive()))
  106. assertEquals("NICK $NICK", String(sendLineChannel.receive()))
  107. assertEquals("USER $USER_NAME 0 * :$REAL_NAME", String(sendLineChannel.receive()))
  108. }
  109. @Test
  110. fun `sends password first, when present`() = runBlocking {
  111. val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
  112. host = HOST
  113. port = PORT
  114. password = PASSWORD
  115. }, profileConfig, BehaviourConfig(), null))
  116. client.socketFactory = mockSocketFactory
  117. client.connect()
  118. assertEquals("CAP LS 302", String(sendLineChannel.receive()))
  119. assertEquals("PASS $PASSWORD", String(sendLineChannel.receive()))
  120. }
  121. @Test
  122. fun `sends events to provided event handler`() {
  123. val client = IrcClientImpl(normalConfig)
  124. client.socketFactory = mockSocketFactory
  125. client.onEvent(mockEventHandler)
  126. GlobalScope.launch {
  127. readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
  128. }
  129. client.connect()
  130. verify(mockEventHandler, timeout(500)).invoke(isA<ServerWelcome>())
  131. }
  132. @Test
  133. fun `gets case mapping from server features`() {
  134. val client = IrcClientImpl(normalConfig)
  135. client.serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.RfcStrict
  136. assertEquals(CaseMapping.RfcStrict, client.caseMapping)
  137. }
  138. @Test
  139. fun `indicates if user is local user or not`() {
  140. val client = IrcClientImpl(normalConfig)
  141. client.serverState.localNickname = "[acidBurn]"
  142. assertTrue(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
  143. assertFalse(client.isLocalUser(User("acid-Burn", "libby", "root.localhost")))
  144. }
  145. @Test
  146. fun `indicates if nickname is local user or not`() {
  147. val client = IrcClientImpl(normalConfig)
  148. client.serverState.localNickname = "[acidBurn]"
  149. assertTrue(client.isLocalUser("{acidBurn}"))
  150. assertFalse(client.isLocalUser("acid-Burn"))
  151. }
  152. @Test
  153. fun `uses current case mapping to check local user`() {
  154. val client = IrcClientImpl(normalConfig)
  155. client.serverState.localNickname = "[acidBurn]"
  156. client.serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.Ascii
  157. assertFalse(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
  158. }
  159. @Test
  160. @SuppressWarnings("deprecation")
  161. fun `sends text to socket`() = runBlocking {
  162. val client = IrcClientImpl(normalConfig)
  163. client.socketFactory = mockSocketFactory
  164. client.connect()
  165. client.send("testing 123")
  166. assertLineReceived("testing 123")
  167. }
  168. @Test
  169. fun `sends structured text to socket`() = runBlocking {
  170. val client = IrcClientImpl(normalConfig)
  171. client.socketFactory = mockSocketFactory
  172. client.connect()
  173. client.send("testing", "123", "456")
  174. assertLineReceived("testing 123 456")
  175. }
  176. @Test
  177. fun `echoes message event when behaviour is set and cap is unsupported`() = runBlocking {
  178. val config = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig().apply { alwaysEchoMessages = true }, null)
  179. val client = IrcClientImpl(config)
  180. client.socketFactory = mockSocketFactory
  181. val slot = slot<MessageReceived>()
  182. val mockkEventHandler = mockk<(IrcEvent) -> Unit>(relaxed = true)
  183. every { mockkEventHandler(capture(slot)) } just Runs
  184. client.onEvent(mockkEventHandler)
  185. client.connect()
  186. client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")
  187. assertTrue(slot.isCaptured)
  188. val event = slot.captured
  189. assertEquals("#thegibson", event.target)
  190. assertEquals("Mess with the best, die like the rest", event.message)
  191. assertEquals(NICK, event.user.nickname)
  192. assertEquals(TestConstants.time, event.metadata.time)
  193. }
  194. @Test
  195. fun `does not echo message event when behaviour is set and cap is supported`() = runBlocking {
  196. val config = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig().apply { alwaysEchoMessages = true }, null)
  197. val client = IrcClientImpl(config)
  198. client.socketFactory = mockSocketFactory
  199. client.serverState.capabilities.enabledCapabilities[Capability.EchoMessages] = ""
  200. client.connect()
  201. client.onEvent(mockEventHandler)
  202. client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")
  203. verify(mockEventHandler, never()).invoke(isA<MessageReceived>())
  204. }
  205. @Test
  206. fun `does not echo message event when behaviour is unset`() = runBlocking {
  207. val config = IrcClientConfig(serverConfig, profileConfig, BehaviourConfig().apply { alwaysEchoMessages = false }, null)
  208. val client = IrcClientImpl(config)
  209. client.socketFactory = mockSocketFactory
  210. client.connect()
  211. client.onEvent(mockEventHandler)
  212. client.send("PRIVMSG", "#thegibson", "Mess with the best, die like the rest")
  213. verify(mockEventHandler, never()).invoke(isA<MessageReceived>())
  214. }
  215. @Test
  216. fun `sends structured text to socket with tags`() = runBlocking {
  217. val client = IrcClientImpl(normalConfig)
  218. client.socketFactory = mockSocketFactory
  219. client.connect()
  220. client.send(tagMap(MessageTag.AccountName to "acidB"), "testing", "123", "456")
  221. assertLineReceived("@account=acidB testing 123 456")
  222. }
  223. @Test
  224. fun `sends text to socket without label if cap is missing`() = runBlocking {
  225. val client = IrcClientImpl(normalConfig)
  226. client.socketFactory = mockSocketFactory
  227. client.connect()
  228. client.sendWithLabel(tagMap(), "testing", "123")
  229. assertLineReceived("testing 123")
  230. }
  231. @Test
  232. fun `sends text to socket with added tags and label`() = runBlocking {
  233. generateLabel = { "abc123" }
  234. val client = IrcClientImpl(normalConfig)
  235. client.socketFactory = mockSocketFactory
  236. client.serverState.capabilities.enabledCapabilities[Capability.LabeledResponse] = ""
  237. client.connect()
  238. client.sendWithLabel(tagMap(), "testing", "123")
  239. assertLineReceived("@draft/label=abc123 testing 123")
  240. }
  241. @Test
  242. fun `sends tagged text to socket with label`() = runBlocking {
  243. generateLabel = { "abc123" }
  244. val client = IrcClientImpl(normalConfig)
  245. client.socketFactory = mockSocketFactory
  246. client.serverState.capabilities.enabledCapabilities[Capability.LabeledResponse] = ""
  247. client.connect()
  248. client.sendWithLabel(tagMap(MessageTag.AccountName to "x"), "testing", "123")
  249. assertLineReceived("@account=x;draft/label=abc123 testing 123")
  250. }
  251. @Test
  252. fun `disconnects the socket`() = runBlocking {
  253. val client = IrcClientImpl(normalConfig)
  254. client.socketFactory = mockSocketFactory
  255. client.connect()
  256. client.disconnect()
  257. verify(mockSocket, timeout(500)).disconnect()
  258. }
  259. @Test
  260. @ObsoleteCoroutinesApi
  261. fun `sends messages in order`() = runBlocking {
  262. val client = IrcClientImpl(normalConfig)
  263. client.socketFactory = mockSocketFactory
  264. client.connect()
  265. (0..100).forEach { client.send("TEST", "$it") }
  266. assertEquals(100, withTimeoutOrNull(500) {
  267. var next = 0
  268. for (line in sendLineChannel.map { String(it) }.filter { it.startsWith("TEST ") }) {
  269. assertEquals("TEST $next", line)
  270. if (++next == 100) {
  271. break
  272. }
  273. }
  274. next
  275. })
  276. }
  277. @Test
  278. fun `defaults local nickname to profile`() {
  279. val client = IrcClientImpl(normalConfig)
  280. assertEquals(NICK, client.serverState.localNickname)
  281. }
  282. @Test
  283. fun `defaults server name to host name`() {
  284. val client = IrcClientImpl(normalConfig)
  285. assertEquals(HOST, client.serverState.serverName)
  286. }
  287. @Test
  288. fun `exposes behaviour config`() {
  289. val client = IrcClientImpl(IrcClientConfig(
  290. ServerConfig().apply { host = HOST },
  291. profileConfig,
  292. BehaviourConfig().apply { requestModesOnJoin = true },
  293. null))
  294. assertTrue(client.behaviour.requestModesOnJoin)
  295. }
  296. @Test
  297. fun `reset clears all state`() {
  298. with(IrcClientImpl(normalConfig)) {
  299. userState += User("acidBurn")
  300. channelState += ChannelState("#thegibson") { CaseMapping.Rfc }
  301. serverState.serverName = "root.$HOST"
  302. reset()
  303. assertEquals(0, userState.count())
  304. assertEquals(0, channelState.count())
  305. assertEquals(HOST, serverState.serverName)
  306. }
  307. }
  308. @Test
  309. fun `sends connect error when host is unresolvable`() = runBlocking {
  310. whenever(mockSocket.connect()).doThrow(UnresolvedAddressException())
  311. with(IrcClientImpl(normalConfig)) {
  312. socketFactory = mockSocketFactory
  313. withTimeout(500) {
  314. launch {
  315. delay(50)
  316. connect()
  317. }
  318. val event = waitForEvent<ServerConnectionError>()
  319. assertEquals(ConnectionError.UnresolvableAddress, event.error)
  320. }
  321. }
  322. }
  323. @Test
  324. fun `sends connect error when tls certificate is bad`() = runBlocking {
  325. whenever(mockSocket.connect()).doThrow(CertificateException("Boooo"))
  326. with(IrcClientImpl(normalConfig)) {
  327. socketFactory = mockSocketFactory
  328. withTimeout(500) {
  329. launch {
  330. delay(50)
  331. connect()
  332. }
  333. val event = waitForEvent<ServerConnectionError>()
  334. assertEquals(ConnectionError.BadTlsCertificate, event.error)
  335. assertEquals("Boooo", event.details)
  336. }
  337. }
  338. }
  339. @Test
  340. fun `identifies channels that have a prefix in the chantypes feature`() {
  341. with(IrcClientImpl(normalConfig)) {
  342. serverState.features[ServerFeature.ChannelTypes] = "&~"
  343. assertTrue(isChannel("&dumpsterdiving"))
  344. assertTrue(isChannel("~hacktheplanet"))
  345. assertFalse(isChannel("#root"))
  346. assertFalse(isChannel("acidBurn"))
  347. assertFalse(isChannel(""))
  348. assertFalse(isChannel("acidBurn#~"))
  349. }
  350. }
  351. private suspend inline fun <reified T : IrcEvent> IrcClient.waitForEvent(): T {
  352. val mutex = Mutex(true)
  353. val value = AtomicReference<T>()
  354. onEvent {
  355. if (it is T) {
  356. value.set(it)
  357. mutex.unlock()
  358. }
  359. }
  360. mutex.lock()
  361. return value.get()
  362. }
  363. private suspend fun assertLineReceived(expected: String) {
  364. assertEquals(true, withTimeoutOrNull(500) {
  365. for (line in sendLineChannel.map { String(it) }) {
  366. println(line)
  367. if (line == expected) {
  368. return@withTimeoutOrNull true
  369. }
  370. }
  371. false
  372. }) { "Expected to receive $expected" }
  373. }
  374. }