You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

IrcClientImplTest.kt 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. package com.dmdirc.ktirc
  2. import com.dmdirc.ktirc.events.*
  3. import com.dmdirc.ktirc.io.CaseMapping
  4. import com.dmdirc.ktirc.io.LineBufferedSocket
  5. import com.dmdirc.ktirc.messages.tagMap
  6. import com.dmdirc.ktirc.model.*
  7. import com.dmdirc.ktirc.util.currentTimeProvider
  8. import com.dmdirc.ktirc.util.generateLabel
  9. import com.nhaarman.mockitokotlin2.*
  10. import io.ktor.util.KtorExperimentalAPI
  11. import kotlinx.coroutines.*
  12. import kotlinx.coroutines.channels.Channel
  13. import kotlinx.coroutines.channels.filter
  14. import kotlinx.coroutines.channels.map
  15. import kotlinx.coroutines.sync.Mutex
  16. import org.junit.jupiter.api.Assertions.*
  17. import org.junit.jupiter.api.BeforeEach
  18. import org.junit.jupiter.api.Test
  19. import org.junit.jupiter.api.assertThrows
  20. import java.nio.channels.UnresolvedAddressException
  21. import java.security.cert.CertificateException
  22. import java.util.concurrent.atomic.AtomicReference
  /**
   * Tests for [IrcClientImpl].
   *
   * The client is given a mocked [LineBufferedSocket] (via a mocked socket factory) whose receive and send
   * channels are plain unlimited [Channel]s. Tests feed server lines into [readLineChannel] and read the
   * client's output from [sendLineChannel].
   */
  @KtorExperimentalAPI
  @ExperimentalCoroutinesApi
  internal class IrcClientImplTest {

      /** Lines the mocked socket hands to the client as if received from the server; unlimited capacity. */
      private val readLineChannel = Channel<ByteArray>(Channel.UNLIMITED)

      /** Lines the client writes to the mocked socket. */
      private val sendLineChannel = Channel<ByteArray>(Channel.UNLIMITED)

      private val mockSocket = mock<LineBufferedSocket> {
          on { receiveChannel } doReturn readLineChannel
          on { sendChannel } doReturn sendLineChannel
      }

      /** Returns [mockSocket] for connections to [HOST]:[PORT], whatever the scope or TLS flag. */
      private val mockSocketFactory = mock<(CoroutineScope, String, Int, Boolean) -> LineBufferedSocket> {
          on { invoke(any(), eq(HOST), eq(PORT), any()) } doReturn mockSocket
      }

      private val mockEventHandler = mock<(IrcEvent) -> Unit>()

      private val profileConfig = ProfileConfig().apply {
          nickname = NICK
          realName = REAL_NAME
          username = USER_NAME
      }

      /** Non-TLS connection to [HOST]:[PORT] without a server password. */
      private val normalConfig = IrcClientConfig(ServerConfig().apply {
          host = HOST
          port = PORT
      }, profileConfig, BehaviourConfig(), null)

      @BeforeEach
      fun setUp() {
          // Pin the clock so event metadata timestamps are predictable.
          currentTimeProvider = { TestConstants.time }
      }

      @Test
      fun `uses socket factory to create a new socket on connect`() {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()
          verify(mockSocketFactory, timeout(500)).invoke(client, HOST, PORT, false)
      }

      @Test
      fun `uses socket factory to create a new tls on connect`() {
          val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
              host = HOST
              port = PORT
              useTls = true
          }, profileConfig, BehaviourConfig(), null))
          client.socketFactory = mockSocketFactory
          client.connect()
          verify(mockSocketFactory, timeout(500)).invoke(client, HOST, PORT, true)
      }

      @Test
      fun `throws if socket already exists`() {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()
          assertThrows<IllegalStateException> {
              client.connect()
          }
      }

      @Test
      fun `emits connection events with local time`() = runBlocking {
          // The clock is already pinned to TestConstants.time by setUp().
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.onEvent(mockEventHandler)
          client.connect()
          val captor = argumentCaptor<IrcEvent>()
          verify(mockEventHandler, timeout(500).atLeast(2)).invoke(captor.capture())
          assertTrue(captor.firstValue is ServerConnecting)
          assertEquals(TestConstants.time, captor.firstValue.metadata.time)
          assertTrue(captor.secondValue is ServerConnected)
          assertEquals(TestConstants.time, captor.secondValue.metadata.time)
      }

      @Test
      fun `sends basic connection strings`() = runBlocking {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()
          assertEquals("CAP LS 302", String(sendLineChannel.receive()))
          assertEquals("NICK $NICK", String(sendLineChannel.receive()))
          assertEquals("USER $USER_NAME 0 * :$REAL_NAME", String(sendLineChannel.receive()))
      }

      @Test
      fun `sends password first, when present`() = runBlocking {
          val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
              host = HOST
              port = PORT
              password = PASSWORD
          }, profileConfig, BehaviourConfig(), null))
          client.socketFactory = mockSocketFactory
          client.connect()
          assertEquals("CAP LS 302", String(sendLineChannel.receive()))
          assertEquals("PASS $PASSWORD", String(sendLineChannel.receive()))
      }

      @Test
      fun `sends events to provided event handler`() = runBlocking {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.onEvent(mockEventHandler)
          // readLineChannel is unlimited, so the line is buffered up front instead of racing connect()
          // from an unstructured GlobalScope coroutine.
          readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
          client.connect()
          verify(mockEventHandler, timeout(500)).invoke(isA<ServerWelcome>())
      }

      @Test
      fun `gets case mapping from server features`() {
          val client = IrcClientImpl(normalConfig)
          client.serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.RfcStrict
          assertEquals(CaseMapping.RfcStrict, client.caseMapping)
      }

      @Test
      fun `indicates if user is local user or not`() {
          val client = IrcClientImpl(normalConfig)
          client.serverState.localNickname = "[acidBurn]"
          assertTrue(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
          assertFalse(client.isLocalUser(User("acid-Burn", "libby", "root.localhost")))
      }

      @Test
      fun `indicates if nickname is local user or not`() {
          val client = IrcClientImpl(normalConfig)
          client.serverState.localNickname = "[acidBurn]"
          assertTrue(client.isLocalUser("{acidBurn}"))
          assertFalse(client.isLocalUser("acid-Burn"))
      }

      @Test
      fun `uses current case mapping to check local user`() {
          val client = IrcClientImpl(normalConfig)
          client.serverState.localNickname = "[acidBurn]"
          client.serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.Ascii
          assertFalse(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
      }

      @Test
      fun `sends text to socket`() = runBlocking {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()
          client.send("testing 123")
          assertLineReceived("testing 123")
      }

      @Test
      fun `sends structured text to socket`() = runBlocking {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()
          client.send("testing", "123", "456")
          assertLineReceived("testing 123 456")
      }

      @Test
      fun `sends structured text to socket with tags`() = runBlocking {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()
          client.send(tagMap(MessageTag.AccountName to "acidB"), "testing", "123", "456")
          assertLineReceived("@account=acidB testing 123 456")
      }

      @Test
      fun `sends text to socket without label if cap is missing`() = runBlocking {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()
          client.sendWithLabel(tagMap(), "testing", "123")
          assertLineReceived("testing 123")
      }

      @Test
      fun `sends text to socket with added tags and label`() = runBlocking {
          generateLabel = { "abc123" }
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.serverState.capabilities.enabledCapabilities[Capability.LabeledResponse] = ""
          client.connect()
          client.sendWithLabel(tagMap(), "testing", "123")
          assertLineReceived("@draft/label=abc123 testing 123")
      }

      @Test
      fun `sends tagged text to socket with label`() = runBlocking {
          generateLabel = { "abc123" }
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.serverState.capabilities.enabledCapabilities[Capability.LabeledResponse] = ""
          client.connect()
          client.sendWithLabel(tagMap(MessageTag.AccountName to "x"), "testing", "123")
          assertLineReceived("@account=x;draft/label=abc123 testing 123")
      }

      @Test
      fun `disconnects the socket`() = runBlocking {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()
          client.disconnect()
          verify(mockSocket, timeout(500)).disconnect()
      }

      @Test
      fun `sends messages in order`() = runBlocking {
          val messageCount = 100
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()
          repeat(messageCount) { client.send("TEST", "$it") }
          val checked = withTimeoutOrNull(500) {
              var next = 0
              // Read the channel directly: the channel map/filter operators start producers that leak.
              for (raw in sendLineChannel) {
                  val line = String(raw)
                  // Skip the connection preamble (CAP LS / NICK / USER) sent by connect().
                  if (!line.startsWith("TEST ")) continue
                  assertEquals("TEST $next", line)
                  if (++next == messageCount) break
              }
              next
          }
          assertEquals(messageCount, checked)
      }

      @Test
      fun `defaults local nickname to profile`() {
          val client = IrcClientImpl(normalConfig)
          assertEquals(NICK, client.serverState.localNickname)
      }

      @Test
      fun `defaults server name to host name`() {
          val client = IrcClientImpl(normalConfig)
          assertEquals(HOST, client.serverState.serverName)
      }

      @Test
      fun `exposes behaviour config`() {
          val client = IrcClientImpl(IrcClientConfig(
                  ServerConfig().apply { host = HOST },
                  profileConfig,
                  BehaviourConfig().apply { requestModesOnJoin = true },
                  null))
          assertTrue(client.behaviour.requestModesOnJoin)
      }

      @Test
      fun `reset clears all state`() {
          with(IrcClientImpl(normalConfig)) {
              userState += User("acidBurn")
              channelState += ChannelState("#thegibson") { CaseMapping.Rfc }
              serverState.serverName = "root.$HOST"
              reset()
              assertEquals(0, userState.count())
              assertEquals(0, channelState.count())
              assertEquals(HOST, serverState.serverName)
          }
      }

      @Test
      fun `sends connect error when host is unresolvable`() = runBlocking {
          whenever(mockSocket.connect()).doThrow(UnresolvedAddressException())
          with(IrcClientImpl(normalConfig)) {
              socketFactory = mockSocketFactory
              withTimeout(500) {
                  launch {
                      delay(50)
                      connect()
                  }
                  val event = waitForEvent<ServerConnectionError>()
                  assertEquals(ConnectionError.UnresolvableAddress, event.error)
              }
          }
      }

      @Test
      fun `sends connect error when tls certificate is bad`() = runBlocking {
          whenever(mockSocket.connect()).doThrow(CertificateException("Boooo"))
          with(IrcClientImpl(normalConfig)) {
              socketFactory = mockSocketFactory
              withTimeout(500) {
                  launch {
                      delay(50)
                      connect()
                  }
                  val event = waitForEvent<ServerConnectionError>()
                  assertEquals(ConnectionError.BadTlsCertificate, event.error)
                  assertEquals("Boooo", event.details)
              }
          }
      }

      @Test
      fun `identifies channels that have a prefix in the chantypes feature`() {
          with(IrcClientImpl(normalConfig)) {
              serverState.features[ServerFeature.ChannelTypes] = "&~"
              assertTrue(isChannel("&dumpsterdiving"))
              assertTrue(isChannel("~hacktheplanet"))
              assertFalse(isChannel("#root"))
              assertFalse(isChannel("acidBurn"))
              assertFalse(isChannel(""))
              assertFalse(isChannel("acidBurn#~"))
          }
      }

      /**
       * Suspends until this client emits an event of type [T], and returns the first such event.
       *
       * This helper does not remove the listener it registers. Only the first matching event is captured.
       * Later matches are ignored instead of calling `unlock()` again on a mutex that is no longer locked,
       * which throws [IllegalStateException].
       */
      private suspend inline fun <reified T : IrcEvent> IrcClient.waitForEvent(): T {
          val gate = Mutex(true)
          val captured = AtomicReference<T>()
          onEvent { event ->
              if (event is T && captured.compareAndSet(null, event)) {
                  gate.unlock()
              }
          }
          gate.lock()
          return captured.get()
      }

      /**
       * Asserts that [expected] is written to the socket within 500ms.
       *
       * Any other lines sent before it are consumed and discarded. The channel is read directly rather than
       * through the `map` operator. That operator starts a GlobalScope producer which is not cancelled when
       * this loop exits early, so it would leak and swallow the next sent line.
       */
      private suspend fun assertLineReceived(expected: String) {
          val received = withTimeoutOrNull(500) {
              for (raw in sendLineChannel) {
                  if (String(raw) == expected) {
                      return@withTimeoutOrNull true
                  }
              }
              false
          }
          assertEquals(true, received) { "Expected to receive $expected" }
      }

      companion object {
          private const val HOST = "thegibson.com"
          private const val PORT = 12345
          private const val NICK = "AcidBurn"
          private const val REAL_NAME = "Kate Libby"
          private const val USER_NAME = "acidb"
          private const val PASSWORD = "HackThePlanet"
      }
  }