You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ScramMechanism.kt 7.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. package com.dmdirc.ktirc.sasl
  2. import com.dmdirc.ktirc.IrcClient
  3. import com.dmdirc.ktirc.SaslConfig
  4. import com.dmdirc.ktirc.messages.sendAuthenticationMessage
  5. import com.dmdirc.ktirc.util.logger
  6. import java.security.MessageDigest
  7. import java.security.SecureRandom
  8. import javax.crypto.Mac
  9. import javax.crypto.spec.SecretKeySpec
  10. import kotlin.experimental.xor
  11. import kotlin.random.asKotlinRandom
  /**
   * SASL mechanism implementing SCRAM (RFC 5802) over a single hash [algorithm].
   *
   * Each call to [handleAuthenticationEvent] advances the client's [ScramState] by one step:
   * send the client-first-message, then send the client-final-message, then verify the server
   * signature. Protocol failures abort the exchange instead of throwing to the caller.
   *
   * @param algorithm hash name used for the mechanism name (uppercased), the message digest (as-is)
   *   and the HMAC (hyphens removed, prefixed with "hmac"); presumably a JCA name such as "sha-256" —
   *   confirm with the callers that construct this mechanism.
   * @param priority exposed unchanged as [SaslMechanism.priority].
   * @param saslConfig supplies the username and password used to authenticate.
   */
  internal class ScramMechanism(private val algorithm: String, override val priority: Int, private val saslConfig: SaslConfig) : SaslMechanism {

      private val log by logger()

      override val ircName = "SCRAM-${algorithm.toUpperCase()}"

      /**
       * Runs the next step of the exchange for [client], treating [data] as the server's latest message.
       * Any [ScramException] raised by a step aborts authentication, using its message as the reason.
       */
      override fun handleAuthenticationEvent(client: IrcClient, data: ByteArray?) {
          val state = client.scramState
          try {
              when (state.scramStage) {
                  ScramStage.SendingFirstMessage -> client.sendFirstMessage(state)
                  // Pass the raw bytes: AuthMessage must contain the server-first-message verbatim.
                  ScramStage.SendingSecondMessage -> client.sendSecondMessage(state, data)
                  ScramStage.Finishing -> client.validateAndFinish(state, data.parse())
              }
          } catch (ex: ScramException) {
              client.abortScram(ex.localizedMessage)
          }
      }

      /** Sends the client-first-message: the GS2 header (no channel binding, no authzid) plus the bare message. */
      private fun IrcClient.sendFirstMessage(state: ScramState) {
          state.scramStage = ScramStage.SendingSecondMessage
          sendAuthenticationData("$GS2_HEADER${clientFirstMessageBare(state)}")
      }

      /**
       * Validates the server-first-message, derives the client proof and sends the client-final-message.
       *
       * @throws ScramException if the server reports an error, or omits or sends an invalid nonce,
       *   salt or iteration count.
       */
      private fun IrcClient.sendSecondMessage(state: ScramState, serverFirstMessage: ByteArray?) {
          val data = serverFirstMessage.parse()
          checkForServerErrors(data)

          val iterCount = data[ScramMessageType.IterationCount]?.toIntOrNull()
                  ?: throw ScramException("No iteration count provided")
          if (iterCount < 1) throw ScramException("Invalid iteration count provided: $iterCount")
          val salt = data[ScramMessageType.Salt]?.fromBase64() ?: throw ScramException("No salt provided")
          val serverNonce = data[ScramMessageType.Nonce] ?: throw ScramException("No server nonce provided")
          // RFC 5802 §5.1: the client MUST check the server's nonce begins with the nonce it sent.
          if (!serverNonce.startsWith(state.clientNonce)) {
              throw ScramException("Server nonce does not start with client nonce")
          }

          state.iterCount = iterCount
          state.salt = salt
          state.serverNonce = serverNonce
          state.saltedPassword = pbkdf2(saslConfig.password.toByteArray(), salt, iterCount)

          val clientFinalWithoutProof = buildScramMessage(
                  ScramMessageType.ChannelBinding to GS2_HEADER.toByteArray().toBase64(),
                  ScramMessageType.Nonce to serverNonce)
          // AuthMessage = client-first-bare "," server-first "," client-final-without-proof, where the
          // server-first-message is used exactly as received rather than re-serialised from parsed values.
          state.authMessage = listOf(
                  clientFirstMessageBare(state),
                  String(serverFirstMessage ?: ByteArray(0)),
                  clientFinalWithoutProof).joinToString(",").toByteArray()

          val clientKey = hmac(state.saltedPassword, "Client Key".toByteArray())
          val storedKey = hash(clientKey)
          val clientSignature = hmac(storedKey, state.authMessage)
          val clientProof = clientKey.xor(clientSignature)

          state.scramStage = ScramStage.Finishing
          sendAuthenticationData(buildScramMessage(
                  ScramMessageType.ChannelBinding to GS2_HEADER.toByteArray().toBase64(),
                  ScramMessageType.Nonce to serverNonce,
                  ScramMessageType.ClientProof to clientProof.toBase64()))
      }

      /**
       * Checks the server-final-message's signature and, if it matches, completes authentication
       * by sending an empty ("+") response.
       *
       * @throws ScramException if the server reports an error, omits its verifier, or the verifier is wrong.
       */
      private fun IrcClient.validateAndFinish(state: ScramState, data: Map<ScramMessageType, String>) {
          checkForServerErrors(data)
          val serverKey = hmac(state.saltedPassword, "Server Key".toByteArray())
          val expectedServerSignature = hmac(serverKey, state.authMessage).toBase64()
          val receivedServerSignature = data[ScramMessageType.ServerVerifier]
                  ?: throw ScramException("No server verifier received")
          // Constant-time comparison so timing does not reveal how much of the signature matched.
          if (!MessageDigest.isEqual(expectedServerSignature.toByteArray(), receivedServerSignature.toByteArray())) {
              throw ScramException("Server signature does not match")
          }
          sendAuthenticationMessage("+")
      }

      /** Logs [reason] and cancels the SASL exchange by sending "*". */
      private fun IrcClient.abortScram(reason: String) {
          log.warning { "Aborting SCRAM authentication: $reason" }
          sendAuthenticationMessage("*")
      }

      /** Throws if the server sent a mandatory extension we do not support or an explicit error. */
      private fun checkForServerErrors(data: Map<ScramMessageType, String>) {
          data[ScramMessageType.FutureExtensions]?.let { throw ScramException("Unsupported extension received: $it") }
          data[ScramMessageType.Error]?.let { throw ScramException("Error received from server: $it") }
      }

      /** Builds client-first-message-bare: the escaped username and this client's nonce. */
      private fun clientFirstMessageBare(state: ScramState) = buildScramMessage(
              ScramMessageType.AuthName to saslConfig.username.escape(),
              ScramMessageType.Nonce to state.clientNonce)

      /** Joins attributes as comma-separated `prefix=value` pairs, in the order given. */
      private fun buildScramMessage(vararg entries: Pair<ScramMessageType, String>) =
              entries.joinToString(",") { (type, value) -> "${type.prefix}=$value" }

      /**
       * Splits a server message into its attributes. A null or empty message yields an empty map.
       * Only username values are unescaped; other values (e.g. nonces) are kept exactly as sent.
       *
       * @throws ScramException if an attribute is not of the form `x=value`.
       */
      private fun ByteArray?.parse(): Map<ScramMessageType, String> =
              if (this == null || isEmpty()) {
                  emptyMap()
              } else {
                  String(this).split(',').associate { attribute ->
                      if (attribute.length < 2 || attribute[1] != '=') {
                          throw ScramException("Malformed SCRAM attribute received: $attribute")
                      }
                      val type = getMessageType(attribute[0])
                      val value = attribute.substring(2)
                      type to if (type == ScramMessageType.AuthName) value.unescape() else value
                  }
              }

      // SCRAM saslname escaping: '=' must be escaped first so the '=' introduced for ',' is not re-escaped.
      private fun String.escape() = replace("=", "=3D").replace(",", "=2C")

      private fun String.unescape() = replace("=2C", ",").replace("=3D", "=")

      /** Creates an HMAC instance for [algorithm], keyed with [keyMaterial]. */
      private fun newMac(keyMaterial: ByteArray): Mac {
          val mac = Mac.getInstance("hmac${algorithm.replace("-", "")}")
          mac.init(SecretKeySpec(keyMaterial, mac.algorithm))
          return mac
      }

      private fun hmac(keyMaterial: ByteArray, input: ByteArray): ByteArray = newMac(keyMaterial).doFinal(input)

      // The digest is looked up by the standard name (e.g. "SHA-256"); hyphen-less names are not
      // guaranteed aliases on every JDK.
      private fun hash(input: ByteArray): ByteArray = MessageDigest.getInstance(algorithm).digest(input)

      /**
       * RFC 5802 `Hi()`: PBKDF2 with HMAC as the PRF, producing a single output block.
       * Returns an empty array when [iterations] is not positive; callers reject such counts earlier.
       */
      private fun pbkdf2(keyMaterial: ByteArray, initialSalt: ByteArray, iterations: Int): ByteArray {
          if (iterations < 1) return ByteArray(0)
          // One keyed Mac is reused for every round; doFinal resets it ready for the next input.
          val mac = newMac(keyMaterial)
          var block = mac.doFinal(initialSalt + byteArrayOf(0, 0, 0, 1)) // U1 = HMAC(password, salt || INT(1))
          val result = block.copyOf()
          repeat(iterations - 1) {
              block = mac.doFinal(block) // Ui = HMAC(password, Ui-1)
              for (i in result.indices) {
                  result[i] = result[i] xor block[i]
              }
          }
          return result
      }

      /** The client's SCRAM state, created and stored in the SASL mechanism state on first access. */
      private val IrcClient.scramState: ScramState
          get() = with(serverState.sasl) {
              (mechanismState as? ScramState ?: ScramState()).also { mechanismState = it }
          }

      /** Element-wise XOR, truncated to the shorter array. */
      private fun ByteArray.xor(other: ByteArray): ByteArray = zip(other) { a, b -> a.xor(b) }.toByteArray()

      private companion object {
          /** GS2 header: no channel binding support, no authorisation identity. */
          const val GS2_HEADER = "n,,"
      }
  }
  /**
   * Raised when a SCRAM exchange cannot continue.
   * The [message] becomes the reason reported when authentication is aborted.
   */
  private class ScramException constructor(
          message: String
  ) : RuntimeException(message)
  /**
   * Generates a random 32-character client nonce for a SCRAM exchange.
   *
   * RFC 5802 limits nonce characters to printable ASCII `0x21`..`0x7E` excluding `,`. The space
   * character is therefore not allowed. `=` is also excluded so the nonce never interacts with
   * attribute escaping. A default [SecureRandom] is used because
   * [SecureRandom.getInstanceStrong] may block while gathering entropy on some platforms.
   */
  private fun newNonce(): String {
      val charPool: List<Char> = ('!'..'~') - ',' - '='
      val random = SecureRandom().asKotlinRandom()
      return String(CharArray(32) { charPool.random(random) })
  }
  /**
   * Per-client progress of a SCRAM exchange, kept as the SASL mechanism state between events.
   *
   * @property scramStage the step to perform on the next authentication event.
   * @property clientNonce random nonce contributed by this client, generated when the state is created.
   * @property serverNonce nonce received from the server.
   * @property iterCount iteration count received from the server.
   * @property salt decoded salt received from the server.
   * @property saltedPassword the password after salting and iterated hashing.
   * @property authMessage the message that the client and server signatures are computed over.
   */
  internal class ScramState(
          var scramStage: ScramStage = ScramStage.SendingFirstMessage,
          val clientNonce: String = newNonce(),
          var serverNonce: String = "",
          var iterCount: Int = 1,
          var salt: ByteArray = byteArrayOf(),
          var saltedPassword: ByteArray = byteArrayOf(),
          var authMessage: ByteArray = byteArrayOf()
  )
  /**
   * The step a SCRAM exchange will perform on its next authentication event, in the order they occur:
   * send the first message, send the second (proof) message, then verify the server and finish.
   */
  internal enum class ScramStage { SendingFirstMessage, SendingSecondMessage, Finishing }
  /**
   * Attribute types that can appear in SCRAM messages. Each is identified by the single-character
   * [prefix] written before the `=` in a `prefix=value` attribute.
   */
  internal enum class ScramMessageType(val prefix: Char) {
      /** `n`: the (escaped) username. */
      AuthName('n'),

      /** `m`: reserved mandatory extension; receiving it aborts the exchange. */
      FutureExtensions('m'),

      /** `r`: a nonce. */
      Nonce('r'),

      /** `c`: base64-encoded channel binding data (the GS2 header). */
      ChannelBinding('c'),

      /** `s`: base64-encoded salt. */
      Salt('s'),

      /** `i`: iteration count. */
      IterationCount('i'),

      /** `p`: base64-encoded client proof. */
      ClientProof('p'),

      /** `v`: base64-encoded server signature. */
      ServerVerifier('v'),

      /** `e`: an error reported by the server. */
      Error('e')
  }
  /**
   * Looks up the [ScramMessageType] whose prefix is [prefix].
   * Any unrecognised prefix is reported as [ScramMessageType.Error].
   */
  private fun getMessageType(prefix: Char): ScramMessageType {
      for (candidate in ScramMessageType.values()) {
          if (candidate.prefix == prefix) return candidate
      }
      return ScramMessageType.Error
  }