You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

IrcClientImplTest.kt 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. package com.dmdirc.ktirc
  2. import com.dmdirc.ktirc.events.*
  3. import com.dmdirc.ktirc.io.CaseMapping
  4. import com.dmdirc.ktirc.io.LineBufferedSocket
  5. import com.dmdirc.ktirc.model.ChannelState
  6. import com.dmdirc.ktirc.model.ConnectionError
  7. import com.dmdirc.ktirc.model.ServerFeature
  8. import com.dmdirc.ktirc.model.User
  9. import com.dmdirc.ktirc.util.currentTimeProvider
  10. import com.nhaarman.mockitokotlin2.*
  11. import io.ktor.util.KtorExperimentalAPI
  12. import kotlinx.coroutines.*
  13. import kotlinx.coroutines.channels.Channel
  14. import kotlinx.coroutines.channels.filter
  15. import kotlinx.coroutines.channels.map
  16. import kotlinx.coroutines.sync.Mutex
  17. import org.junit.jupiter.api.Assertions.*
  18. import org.junit.jupiter.api.BeforeEach
  19. import org.junit.jupiter.api.Test
  20. import org.junit.jupiter.api.assertThrows
  21. import java.nio.channels.UnresolvedAddressException
  22. import java.security.cert.CertificateException
  23. import java.util.concurrent.atomic.AtomicReference
  24. @KtorExperimentalAPI
  25. @ExperimentalCoroutinesApi
  26. internal class IrcClientImplTest {
  27. companion object {
  28. private const val HOST = "thegibson.com"
  29. private const val PORT = 12345
  30. private const val NICK = "AcidBurn"
  31. private const val REAL_NAME = "Kate Libby"
  32. private const val USER_NAME = "acidb"
  33. private const val PASSWORD = "HackThePlanet"
  34. }
  35. private val readLineChannel = Channel<ByteArray>(Channel.UNLIMITED)
  36. private val sendLineChannel = Channel<ByteArray>(Channel.UNLIMITED)
  37. private val mockSocket = mock<LineBufferedSocket> {
  38. on { receiveChannel } doReturn readLineChannel
  39. on { sendChannel } doReturn sendLineChannel
  40. }
  41. private val mockSocketFactory = mock<(CoroutineScope, String, Int, Boolean) -> LineBufferedSocket> {
  42. on { invoke(any(), eq(HOST), eq(PORT), any()) } doReturn mockSocket
  43. }
  44. private val mockEventHandler = mock<(IrcEvent) -> Unit>()
  45. private val profileConfig = ProfileConfig().apply {
  46. nickname = NICK
  47. realName = REAL_NAME
  48. username = USER_NAME
  49. }
  50. private val normalConfig = IrcClientConfig(ServerConfig().apply {
  51. host = HOST
  52. port = PORT
  53. }, profileConfig, BehaviourConfig(), null)
  54. @BeforeEach
  55. fun setUp() {
  56. currentTimeProvider = { TestConstants.time }
  57. }
  58. @Test
  59. fun `uses socket factory to create a new socket on connect`() {
  60. val client = IrcClientImpl(normalConfig)
  61. client.socketFactory = mockSocketFactory
  62. client.connect()
  63. verify(mockSocketFactory, timeout(500)).invoke(client, HOST, PORT, false)
  64. }
  65. @Test
  66. fun `uses socket factory to create a new tls on connect`() {
  67. val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
  68. host = HOST
  69. port = PORT
  70. useTls = true
  71. }, profileConfig, BehaviourConfig(), null))
  72. client.socketFactory = mockSocketFactory
  73. client.connect()
  74. verify(mockSocketFactory, timeout(500)).invoke(client, HOST, PORT, true)
  75. }
  76. @Test
  77. fun `throws if socket already exists`() {
  78. val client = IrcClientImpl(normalConfig)
  79. client.socketFactory = mockSocketFactory
  80. client.connect()
  81. assertThrows<IllegalStateException> {
  82. client.connect()
  83. }
  84. }
  85. @Test
  86. fun `emits connection events with local time`() = runBlocking {
  87. currentTimeProvider = { TestConstants.time }
  88. val client = IrcClientImpl(normalConfig)
  89. client.socketFactory = mockSocketFactory
  90. client.onEvent(mockEventHandler)
  91. client.connect()
  92. val captor = argumentCaptor<IrcEvent>()
  93. verify(mockEventHandler, timeout(500).atLeast(2)).invoke(captor.capture())
  94. assertTrue(captor.firstValue is ServerConnecting)
  95. assertEquals(TestConstants.time, captor.firstValue.time)
  96. assertTrue(captor.secondValue is ServerConnected)
  97. assertEquals(TestConstants.time, captor.secondValue.time)
  98. }
  99. @Test
  100. fun `sends basic connection strings`() = runBlocking {
  101. val client = IrcClientImpl(normalConfig)
  102. client.socketFactory = mockSocketFactory
  103. client.connect()
  104. assertEquals("CAP LS 302", String(sendLineChannel.receive()))
  105. assertEquals("NICK :$NICK", String(sendLineChannel.receive()))
  106. assertEquals("USER $USER_NAME 0 * :$REAL_NAME", String(sendLineChannel.receive()))
  107. }
  108. @Test
  109. fun `sends password first, when present`() = runBlocking {
  110. val client = IrcClientImpl(IrcClientConfig(ServerConfig().apply {
  111. host = HOST
  112. port = PORT
  113. password = PASSWORD
  114. }, profileConfig, BehaviourConfig(), null))
  115. client.socketFactory = mockSocketFactory
  116. client.connect()
  117. assertEquals("CAP LS 302", String(sendLineChannel.receive()))
  118. assertEquals("PASS :$PASSWORD", String(sendLineChannel.receive()))
  119. }
  120. @Test
  121. fun `sends events to provided event handler`() {
  122. val client = IrcClientImpl(normalConfig)
  123. client.socketFactory = mockSocketFactory
  124. client.onEvent(mockEventHandler)
  125. GlobalScope.launch {
  126. readLineChannel.send(":the.gibson 001 acidBurn :Welcome to the IRC!".toByteArray())
  127. }
  128. client.connect()
  129. verify(mockEventHandler, timeout(500)).invoke(isA<ServerWelcome>())
  130. }
  131. @Test
  132. fun `gets case mapping from server features`() {
  133. val client = IrcClientImpl(normalConfig)
  134. client.serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.RfcStrict
  135. assertEquals(CaseMapping.RfcStrict, client.caseMapping)
  136. }
  137. @Test
  138. fun `indicates if user is local user or not`() {
  139. val client = IrcClientImpl(normalConfig)
  140. client.serverState.localNickname = "[acidBurn]"
  141. assertTrue(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
  142. assertFalse(client.isLocalUser(User("acid-Burn", "libby", "root.localhost")))
  143. }
  144. @Test
  145. fun `indicates if nickname is local user or not`() {
  146. val client = IrcClientImpl(normalConfig)
  147. client.serverState.localNickname = "[acidBurn]"
  148. assertTrue(client.isLocalUser("{acidBurn}"))
  149. assertFalse(client.isLocalUser("acid-Burn"))
  150. }
  151. @Test
  152. fun `uses current case mapping to check local user`() {
  153. val client = IrcClientImpl(normalConfig)
  154. client.serverState.localNickname = "[acidBurn]"
  155. client.serverState.features[ServerFeature.ServerCaseMapping] = CaseMapping.Ascii
  156. assertFalse(client.isLocalUser(User("{acidBurn}", "libby", "root.localhost")))
  157. }
  158. @Test
  159. fun `sends text to socket`() = runBlocking {
  160. val client = IrcClientImpl(normalConfig)
  161. client.socketFactory = mockSocketFactory
  162. client.connect()
  163. client.send("testing 123")
  164. assertEquals(true, withTimeoutOrNull(500) {
  165. var found = false
  166. for (line in sendLineChannel) {
  167. if (String(line) == "testing 123") {
  168. found = true
  169. break
  170. }
  171. }
  172. found
  173. })
  174. }
  175. @Test
  176. fun `disconnects the socket`() = runBlocking {
  177. val client = IrcClientImpl(normalConfig)
  178. client.socketFactory = mockSocketFactory
  179. client.connect()
  180. client.disconnect()
  181. verify(mockSocket, timeout(500)).disconnect()
  182. }
  183. @Test
  184. @ObsoleteCoroutinesApi
  185. fun `sends messages in order`() = runBlocking {
  186. val client = IrcClientImpl(normalConfig)
  187. client.socketFactory = mockSocketFactory
  188. client.connect()
  189. (0..100).forEach { client.send("TEST $it") }
  190. assertEquals(100, withTimeoutOrNull(500) {
  191. var next = 0
  192. for (line in sendLineChannel.map { String(it) }.filter { it.startsWith("TEST ") }) {
  193. assertEquals("TEST $next", line)
  194. if (++next == 100) {
  195. break
  196. }
  197. }
  198. next
  199. })
  200. }
  201. @Test
  202. fun `defaults local nickname to profile`() {
  203. val client = IrcClientImpl(normalConfig)
  204. assertEquals(NICK, client.serverState.localNickname)
  205. }
  206. @Test
  207. fun `defaults server name to host name`() {
  208. val client = IrcClientImpl(normalConfig)
  209. assertEquals(HOST, client.serverState.serverName)
  210. }
  211. @Test
  212. fun `exposes behaviour config`() {
  213. val client = IrcClientImpl(IrcClientConfig(
  214. ServerConfig().apply { host = HOST },
  215. profileConfig,
  216. BehaviourConfig().apply { requestModesOnJoin = true },
  217. null))
  218. assertTrue(client.behaviour.requestModesOnJoin)
  219. }
  220. @Test
  221. fun `reset clears all state`() {
  222. with(IrcClientImpl(normalConfig)) {
  223. userState += User("acidBurn")
  224. channelState += ChannelState("#thegibson") { CaseMapping.Rfc }
  225. serverState.serverName = "root.$HOST"
  226. reset()
  227. assertEquals(0, userState.count())
  228. assertEquals(0, channelState.count())
  229. assertEquals(HOST, serverState.serverName)
  230. }
  231. }
  232. @Test
  233. fun `sends connect error when host is unresolvable`() = runBlocking {
  234. whenever(mockSocket.connect()).doThrow(UnresolvedAddressException())
  235. with(IrcClientImpl(normalConfig)) {
  236. socketFactory = mockSocketFactory
  237. withTimeout(500) {
  238. launch {
  239. delay(50)
  240. connect()
  241. }
  242. val event = waitForEvent<ServerConnectionError>()
  243. assertEquals(ConnectionError.UnresolvableAddress, event.error)
  244. }
  245. }
  246. }
  247. @Test
  248. fun `sends connect error when tls certificate is bad`() = runBlocking {
  249. whenever(mockSocket.connect()).doThrow(CertificateException("Boooo"))
  250. with(IrcClientImpl(normalConfig)) {
  251. socketFactory = mockSocketFactory
  252. withTimeout(500) {
  253. launch {
  254. delay(50)
  255. connect()
  256. }
  257. val event = waitForEvent<ServerConnectionError>()
  258. assertEquals(ConnectionError.BadTlsCertificate, event.error)
  259. assertEquals("Boooo", event.details)
  260. }
  261. }
  262. }
  263. @Test
  264. fun `identifies channels that have a prefix in the chantypes feature`() {
  265. with(IrcClientImpl(normalConfig)) {
  266. serverState.features[ServerFeature.ChannelTypes] = "&~"
  267. assertTrue(isChannel("&dumpsterdiving"))
  268. assertTrue(isChannel("~hacktheplanet"))
  269. assertFalse(isChannel("#root"))
  270. assertFalse(isChannel("acidBurn"))
  271. assertFalse(isChannel(""))
  272. assertFalse(isChannel("acidBurn#~"))
  273. }
  274. }
  275. private suspend inline fun <reified T : IrcEvent> IrcClient.waitForEvent(): T {
  276. val mutex = Mutex(true)
  277. val value = AtomicReference<T>()
  278. onEvent {
  279. if (it is T) {
  280. value.set(it)
  281. mutex.unlock()
  282. }
  283. }
  284. mutex.lock()
  285. return value.get()
  286. }
  287. }