You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

IrcClientImplTest.kt 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. package com.dmdirc.ktirc
  2. import com.dmdirc.ktirc.events.*
  3. import com.dmdirc.ktirc.io.CaseMapping
  4. import com.dmdirc.ktirc.io.LineBufferedSocket
  5. import com.dmdirc.ktirc.model.*
  6. import com.dmdirc.ktirc.util.currentTimeProvider
  7. import com.dmdirc.ktirc.util.generateLabel
  8. import com.nhaarman.mockitokotlin2.*
  9. import io.ktor.util.KtorExperimentalAPI
  10. import kotlinx.coroutines.*
  11. import kotlinx.coroutines.channels.Channel
  12. import kotlinx.coroutines.channels.filter
  13. import kotlinx.coroutines.channels.map
  14. import kotlinx.coroutines.sync.Mutex
  15. import org.junit.jupiter.api.Assertions.*
  16. import org.junit.jupiter.api.BeforeEach
  17. import org.junit.jupiter.api.Test
  18. import org.junit.jupiter.api.assertThrows
  19. import java.nio.channels.UnresolvedAddressException
  20. import java.security.cert.CertificateException
  21. import java.util.concurrent.atomic.AtomicReference
  /**
   * Unit tests for [IrcClientImpl].
   *
   * The client is wired to a mocked [LineBufferedSocket] (returned by a mocked socket factory) whose receive and send
   * channels are in-memory [Channel]s, so tests can push "server" lines in and read back the lines the client sends.
   */
  @KtorExperimentalAPI
  @ExperimentalCoroutinesApi
  internal class IrcClientImplTest {

      companion object {
          private const val HOST = "thegibson.com"
          private const val PORT = 12345
          private const val NICK = "AcidBurn"
          private const val REAL_NAME = "Kate Libby"
          private const val USER_NAME = "acidb"
          private const val PASSWORD = "HackThePlanet"
      }

      /** Lines delivered to the client as if read from the server; tests send into this to simulate traffic. */
      private val readLineChannel = Channel<ByteArray>(Channel.UNLIMITED)

      /** Lines the client writes to the socket; tests receive from this to assert on outgoing traffic. */
      private val sendLineChannel = Channel<ByteArray>(Channel.UNLIMITED)

      private val mockSocket = mock<LineBufferedSocket> {
          on { receiveChannel } doReturn readLineChannel
          on { sendChannel } doReturn sendLineChannel
      }

      private val mockSocketFactory = mock<(CoroutineScope, String, Int, Boolean) -> LineBufferedSocket> {
          on { invoke(any(), eq(HOST), eq(PORT), any()) } doReturn mockSocket
      }

      private val mockEventHandler = mock<(IrcEvent) -> Unit>()

      private val profileConfig = ProfileConfig().apply {
          nickname = NICK
          realName = REAL_NAME
          username = USER_NAME
      }

      private val normalConfig = configWithServer()

      // These global hooks are overwritten by tests in this class. Remember the values in place when this
      // instance was created so tearDown can restore them and the overrides don't leak into other tests.
      private val originalTimeProvider = currentTimeProvider
      private val originalLabelGenerator = generateLabel

      @BeforeEach
      fun setUp() {
          currentTimeProvider = { TestConstants.time }
      }

      @org.junit.jupiter.api.AfterEach
      fun tearDown() {
          currentTimeProvider = originalTimeProvider
          generateLabel = originalLabelGenerator
      }

      @Test
      fun `uses socket factory to create a new socket on connect`() {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()

          verify(mockSocketFactory, timeout(500)).invoke(client, HOST, PORT, false)
      }

      @Test
      fun `uses socket factory to create a new tls on connect`() {
          val client = IrcClientImpl(configWithServer { useTls = true })
          client.socketFactory = mockSocketFactory
          client.connect()

          verify(mockSocketFactory, timeout(500)).invoke(client, HOST, PORT, true)
      }

      @Test
      fun `throws if socket already exists`() {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()

          assertThrows<IllegalStateException> {
              client.connect()
          }
      }

      @Test
      fun `emits connection events with local time`() = runBlocking {
          // currentTimeProvider is already pinned to TestConstants.time by setUp.
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.onEvent(mockEventHandler)
          client.connect()

          val captor = argumentCaptor<IrcEvent>()
          verify(mockEventHandler, timeout(500).atLeast(2)).invoke(captor.capture())
          assertTrue(captor.firstValue is ServerConnecting)
          assertEquals(TestConstants.time, captor.firstValue.metadata.time)
          assertTrue(captor.secondValue is ServerConnected)
          assertEquals(TestConstants.time, captor.secondValue.metadata.time)
      }

      @Test
      fun `sends basic connection strings`() = runBlocking {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()

          assertEquals("CAP LS 302", String(sendLineChannel.receive()))
          assertEquals("NICK :$NICK", String(sendLineChannel.receive()))
          assertEquals("USER $USER_NAME 0 * :$REAL_NAME", String(sendLineChannel.receive()))
      }

      @Test
      fun `sends password first, when present`() = runBlocking {
          val client = IrcClientImpl(configWithServer { password = PASSWORD })
          client.socketFactory = mockSocketFactory
          client.connect()

          assertEquals("CAP LS 302", String(sendLineChannel.receive()))
          assertEquals("PASS :$PASSWORD", String(sendLineChannel.receive()))
      }

      @Test
      fun `sends events to provided event handler`() = runBlocking<Unit> {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.onEvent(mockEventHandler)

          // The channel is unbounded, so this buffers the line immediately for the client to read once connected.
          readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
          client.connect()

          verify(mockEventHandler, timeout(500)).invoke(isA<ServerWelcome>())
      }

      @Test
      fun `gets case mapping from server features`() {
          val client = IrcClientImpl(normalConfig)
          client.serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.RfcStrict

          assertEquals(CaseMapping.RfcStrict, client.caseMapping)
      }

      @Test
      fun `indicates if user is local user or not`() {
          val client = IrcClientImpl(normalConfig)
          client.serverState.localNickname = "[acidBurn]"

          assertTrue(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
          assertFalse(client.isLocalUser(User("acid-Burn", "libby", "root.localhost")))
      }

      @Test
      fun `indicates if nickname is local user or not`() {
          val client = IrcClientImpl(normalConfig)
          client.serverState.localNickname = "[acidBurn]"

          assertTrue(client.isLocalUser("{acidBurn}"))
          assertFalse(client.isLocalUser("acid-Burn"))
      }

      @Test
      fun `uses current case mapping to check local user`() {
          val client = IrcClientImpl(normalConfig)
          client.serverState.localNickname = "[acidBurn]"
          client.serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.Ascii

          assertFalse(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
      }

      @Test
      fun `sends text to socket`() = runBlocking {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()

          client.send("testing 123")

          assertLineReceived("testing 123")
      }

      @Test
      fun `sends text to socket without label if cap is missing`() = runBlocking {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()

          client.sendWithLabel("testing 123")

          assertLineReceived("testing 123")
      }

      @Test
      fun `sends text to socket with added tags and label`() = runBlocking {
          generateLabel = { "abc123" }
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.serverState.capabilities.enabledCapabilities[Capability.LabeledResponse] = ""
          client.connect()

          client.sendWithLabel("testing 123")

          assertLineReceived("@draft/label=abc123 testing 123")
      }

      @Test
      fun `sends tagged text to socket with label`() = runBlocking {
          generateLabel = { "abc123" }
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.serverState.capabilities.enabledCapabilities[Capability.LabeledResponse] = ""
          client.connect()

          client.sendWithLabel("@+test=x testing 123")

          assertLineReceived("@draft/label=abc123;+test=x testing 123")
      }

      @Test
      fun `disconnects the socket`() = runBlocking {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()

          client.disconnect()

          verify(mockSocket, timeout(500)).disconnect()
      }

      @Test
      fun `sends messages in order`() = runBlocking {
          val client = IrcClientImpl(normalConfig)
          client.socketFactory = mockSocketFactory
          client.connect()

          (0 until 100).forEach { client.send("TEST $it") }

          assertEquals(100, withTimeoutOrNull(500) {
              var next = 0
              // Iterate the channel directly: the obsolete ReceiveChannel.map operator would start a producer
              // coroutine that is never cancelled when we break out early.
              for (bytes in sendLineChannel) {
                  val line = String(bytes)
                  // Skip connection preamble (CAP/NICK/USER) that shares the send channel.
                  if (!line.startsWith("TEST ")) continue
                  assertEquals("TEST $next", line)
                  if (++next == 100) break
              }
              next
          })
      }

      @Test
      fun `defaults local nickname to profile`() {
          val client = IrcClientImpl(normalConfig)

          assertEquals(NICK, client.serverState.localNickname)
      }

      @Test
      fun `defaults server name to host name`() {
          val client = IrcClientImpl(normalConfig)

          assertEquals(HOST, client.serverState.serverName)
      }

      @Test
      fun `exposes behaviour config`() {
          val client = IrcClientImpl(IrcClientConfig(
                  ServerConfig().apply { host = HOST },
                  profileConfig,
                  BehaviourConfig().apply { requestModesOnJoin = true },
                  null))

          assertTrue(client.behaviour.requestModesOnJoin)
      }

      @Test
      fun `reset clears all state`() {
          with(IrcClientImpl(normalConfig)) {
              userState += User("acidBurn")
              channelState += ChannelState("#thegibson") { CaseMapping.Rfc }
              serverState.serverName = "root.$HOST"

              reset()

              assertEquals(0, userState.count())
              assertEquals(0, channelState.count())
              assertEquals(HOST, serverState.serverName)
          }
      }

      @Test
      fun `sends connect error when host is unresolvable`() = runBlocking {
          whenever(mockSocket.connect()).doThrow(UnresolvedAddressException())
          with(IrcClientImpl(normalConfig)) {
              socketFactory = mockSocketFactory
              withTimeout(500) {
                  // Delay the connect so the waitForEvent handler is registered first.
                  launch {
                      delay(50)
                      connect()
                  }
                  val event = waitForEvent<ServerConnectionError>()
                  assertEquals(ConnectionError.UnresolvableAddress, event.error)
              }
          }
      }

      @Test
      fun `sends connect error when tls certificate is bad`() = runBlocking {
          whenever(mockSocket.connect()).doThrow(CertificateException("Boooo"))
          with(IrcClientImpl(normalConfig)) {
              socketFactory = mockSocketFactory
              withTimeout(500) {
                  // Delay the connect so the waitForEvent handler is registered first.
                  launch {
                      delay(50)
                      connect()
                  }
                  val event = waitForEvent<ServerConnectionError>()
                  assertEquals(ConnectionError.BadTlsCertificate, event.error)
                  assertEquals("Boooo", event.details)
              }
          }
      }

      @Test
      fun `identifies channels that have a prefix in the chantypes feature`() {
          with(IrcClientImpl(normalConfig)) {
              serverState.features[ServerFeature.ChannelTypes] = "&~"

              assertTrue(isChannel("&dumpsterdiving"))
              assertTrue(isChannel("~hacktheplanet"))
              assertFalse(isChannel("#root"))
              assertFalse(isChannel("acidBurn"))
              assertFalse(isChannel(""))
              assertFalse(isChannel("acidBurn#~"))
          }
      }

      /**
       * Builds a client config pointing at [HOST]:[PORT] with the shared [profileConfig] and default behaviour.
       *
       * @param configure extra server settings applied after host and port are set (e.g. TLS or a password).
       */
      private fun configWithServer(configure: ServerConfig.() -> Unit = {}) = IrcClientConfig(ServerConfig().apply {
          host = HOST
          port = PORT
          configure()
      }, profileConfig, BehaviourConfig(), null)

      /**
       * Registers an event handler and suspends until the first event of type [T] is emitted by this client.
       *
       * Uses a [CompletableDeferred] so later matching events are ignored (`complete` is a no-op once completed)
       * instead of throwing from inside the client's event dispatch, as a repeated `Mutex.unlock` would.
       */
      private suspend inline fun <reified T : IrcEvent> IrcClient.waitForEvent(): T {
          val result = CompletableDeferred<T>()
          onEvent { event ->
              if (event is T) result.complete(event)
          }
          return result.await()
      }

      /**
       * Suspends until [expected] is seen on [sendLineChannel], failing the test if it isn't seen within 500ms.
       *
       * Lines before the expected one are consumed and discarded.
       */
      private suspend fun assertLineReceived(expected: String) {
          val received = withTimeoutOrNull(500) {
              // Iterate the channel directly rather than via the obsolete ReceiveChannel.map, whose producer
              // coroutine would be left running after an early return.
              for (bytes in sendLineChannel) {
                  if (String(bytes) == expected) {
                      return@withTimeoutOrNull true
                  }
              }
              // Only reached if the channel is closed before the expected line arrives.
              false
          }
          assertEquals(true, received) { "Expected to receive $expected" }
      }
  }