You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

ScramMechanism.kt 7.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. package com.dmdirc.ktirc.sasl
  2. import com.dmdirc.ktirc.IrcClient
  3. import com.dmdirc.ktirc.SaslConfig
  4. import com.dmdirc.ktirc.messages.sendAuthenticationMessage
  5. import com.dmdirc.ktirc.util.logger
  6. import java.security.MessageDigest
  7. import java.security.SecureRandom
  8. import javax.crypto.Mac
  9. import javax.crypto.spec.SecretKeySpec
  10. import kotlin.experimental.xor
  11. import kotlin.random.asKotlinRandom
  /**
   * SASL mechanism that authenticates using SCRAM with a configurable hash [algorithm].
   *
   * The exchange runs in three stages tracked by [ScramState.scramStage]: the client sends its first
   * message, answers the server's challenge with a proof, and finally checks the server's signature.
   * Any [ScramException] raised along the way aborts the exchange.
   *
   * @param algorithm hash algorithm name (dashes are stripped before it is passed to the JCA lookups)
   * @param priority priority of this mechanism, exposed through [SaslMechanism]
   * @param saslConfig supplies the username and password used to authenticate
   */
  internal class ScramMechanism(
      private val algorithm: String,
      override val priority: Int,
      private val saslConfig: SaslConfig
  ) : SaslMechanism {

      private val log by logger()

      override val ircName = "SCRAM-${algorithm.toUpperCase()}"

      /**
       * Advances the SCRAM exchange for [client] by one step, based on the stage stored in its state.
       *
       * @param data payload received from the server, or `null` if there was none
       */
      override fun handleAuthenticationEvent(client: IrcClient, data: ByteArray?) {
          val state = client.scramState
          try {
              when (state.scramStage) {
                  ScramStage.SendingFirstMessage -> client.sendFirstMessage(state)
                  ScramStage.SendingSecondMessage -> client.sendSecondMessage(state, data.parse())
                  ScramStage.Finishing -> client.validateAndFinish(state, data.parse())
              }
          } catch (ex: ScramException) {
              client.abortScram(ex.localizedMessage)
          }
      }

      /** Sends the client-first message (username and client nonce) and moves on to the next stage. */
      private fun IrcClient.sendFirstMessage(state: ScramState) {
          state.scramStage = ScramStage.SendingSecondMessage
          sendScramMessage(
                  "n,,", // No channel binding, no impersonation
                  ScramMessageType.AuthName to saslConfig.username.escape(),
                  ScramMessageType.Nonce to state.clientNonce)
      }

      /**
       * Handles the server's challenge: validates it, derives the salted password and sends the
       * client proof.
       *
       * @throws ScramException if the challenge carries an error or extension, is missing a required
       * attribute, has a non-positive iteration count, or its nonce does not start with our nonce
       */
      private fun IrcClient.sendSecondMessage(state: ScramState, data: Map<ScramMessageType, String>) {
          data.ensureNoErrors()

          val iterCount = data[ScramMessageType.IterationCount]?.toIntOrNull()
                  ?: throw ScramException("No iteration count provided")
          // A count below one would yield an empty salted password, which cannot be used as an HMAC key.
          if (iterCount < 1) {
              throw ScramException("Invalid iteration count: $iterCount")
          }
          state.iterCount = iterCount

          state.salt = data[ScramMessageType.Salt]?.fromBase64() ?: throw ScramException("No salt provided")

          val serverNonce = data[ScramMessageType.Nonce] ?: throw ScramException("No server nonce provided")
          // RFC 5802: the nonce returned by the server must begin with the nonce the client sent.
          if (!serverNonce.startsWith(state.clientNonce)) {
              throw ScramException("Server nonce does not start with client nonce")
          }
          state.serverNonce = serverNonce

          state.saltedPassword = pbkdf2(saslConfig.password.toByteArray(), state.salt, state.iterCount)
          val clientKey = hmac(state.saltedPassword, "Client Key".toByteArray())
          val storedKey = hash(clientKey)

          // The auth message is the client-first (bare), server-first and client-final (without proof)
          // messages joined by commas; it is kept so the server signature can be verified later.
          state.authMessage = buildScramMessage(
                  ScramMessageType.AuthName to saslConfig.username.escape(),
                  ScramMessageType.Nonce to state.clientNonce,
                  ScramMessageType.Nonce to state.serverNonce,
                  ScramMessageType.Salt to state.salt.toBase64(),
                  ScramMessageType.IterationCount to state.iterCount.toString(),
                  ScramMessageType.ChannelBinding to "n,,".toByteArray().toBase64(),
                  ScramMessageType.Nonce to state.serverNonce).toByteArray()

          val clientSignature = hmac(storedKey, state.authMessage)
          val clientProof = clientKey.xor(clientSignature)

          state.scramStage = ScramStage.Finishing
          sendScramMessage(
                  "",
                  ScramMessageType.ChannelBinding to "n,,".toByteArray().toBase64(),
                  ScramMessageType.Nonce to state.serverNonce,
                  ScramMessageType.ClientProof to clientProof.toBase64()
          )
      }

      /**
       * Verifies the server's signature against the one computed locally and, if they match,
       * acknowledges the end of the exchange.
       *
       * @throws ScramException if the message carries an error or extension, has no verifier, or the
       * verifier does not match
       */
      private fun IrcClient.validateAndFinish(state: ScramState, data: Map<ScramMessageType, String>) {
          data.ensureNoErrors()

          val serverKey = hmac(state.saltedPassword, "Server Key".toByteArray())
          val expectedServerSignature = hmac(serverKey, state.authMessage).toBase64()
          val receivedServerSignature = data[ScramMessageType.ServerVerifier]
                  ?: throw ScramException("No server verifier received")

          // MessageDigest.isEqual compares in constant time, unlike String equality.
          val matches = MessageDigest.isEqual(
                  expectedServerSignature.toByteArray(),
                  receivedServerSignature.toByteArray())
          if (!matches) {
              throw ScramException("Server signature does not match")
          }

          sendAuthenticationMessage("+")
      }

      /** Logs [reason] and tells the server the authentication attempt is being abandoned. */
      private fun IrcClient.abortScram(reason: String) {
          log.warning { "Aborting SCRAM authentication: $reason" }
          sendAuthenticationMessage("*")
      }

      /** Rejects any message containing a future-extension or error attribute. */
      private fun Map<ScramMessageType, String>.ensureNoErrors() {
          if (ScramMessageType.FutureExtensions in this)
              throw ScramException("Unsupported extension received: ${this[ScramMessageType.FutureExtensions]}")

          if (ScramMessageType.Error in this)
              throw ScramException("Error received from server: ${this[ScramMessageType.Error]}")
      }

      /** Sends [prefix] followed by the comma-separated [entries] as authentication data. */
      private fun IrcClient.sendScramMessage(prefix: String = "", vararg entries: Pair<ScramMessageType, String>) =
              sendAuthenticationData("$prefix${buildScramMessage(*entries)}")

      /** Formats [entries] as `k=v` pairs joined by commas. */
      private fun buildScramMessage(vararg entries: Pair<ScramMessageType, String>) =
              entries.joinToString(",") { (k, v) -> "${k.prefix}=$v" }

      /**
       * Splits a received payload into its attributes. A `null` or empty payload yields an empty map.
       *
       * @throws ScramException if an attribute is not of the form `<char>=<value>`
       */
      private fun ByteArray?.parse(): Map<ScramMessageType, String> {
          if (this == null || this.isEmpty()) {
              return emptyMap()
          }

          return String(this).split(',').map { entry ->
              // Guard the indexing below: the payload is supplied by the server and may be malformed.
              if (entry.length < 2 || entry[1] != '=') {
                  throw ScramException("Malformed attribute received: $entry")
              }
              getMessageType(entry[0]) to entry.substring(2).unescape()
          }.toMap()
      }

      /** Escapes `=` and `,` so the value can be embedded in a SCRAM message. */
      private fun String.escape() = replace("=", "=3D").replace(",", "=2C")

      /** Reverses [escape]. */
      private fun String.unescape() = replace("=2C", ",").replace("=3D", "=")

      /** Computes the HMAC of [input] keyed with [keyMaterial], using the configured algorithm. */
      private fun hmac(keyMaterial: ByteArray, input: ByteArray): ByteArray {
          return with(Mac.getInstance("hmac${algorithm.replace("-", "")}")) {
              init(SecretKeySpec(keyMaterial, algorithm))
              doFinal(input)
          }
      }

      /** Computes the digest of [input] using the configured algorithm. */
      private fun hash(input: ByteArray): ByteArray {
          return with(MessageDigest.getInstance(algorithm.replace("-", ""))) {
              digest(input)
          }
      }

      /**
       * Derives a key by chaining [iterations] HMAC rounds, starting from [initialSalt] followed by
       * the four bytes `00 00 00 01`, and XOR-ing every round's output together.
       *
       * @return the derived key, or an empty array if [iterations] is not positive
       */
      private fun pbkdf2(keyMaterial: ByteArray, initialSalt: ByteArray, iterations: Int): ByteArray {
          var block = initialSalt + 0x00 + 0x00 + 0x00 + 0x01
          var derived: ByteArray? = null
          repeat(iterations) {
              block = hmac(keyMaterial, block)
              derived = derived?.xor(block) ?: block
          }
          return derived ?: ByteArray(0)
      }

      /**
       * The SCRAM state for this client. If the stored mechanism state is not a [ScramState], a new
       * one is created and stored.
       */
      private val IrcClient.scramState: ScramState
          get() = with(serverState.sasl) {
              (mechanismState as? ScramState ?: ScramState()).apply {
                  mechanismState = this
              }
          }

      /** XORs two arrays byte by byte; the result is as long as the shorter of the two. */
      private fun ByteArray.xor(other: ByteArray): ByteArray = zip(other) { a, b -> a.xor(b) }.toByteArray()

  }
  /**
   * Signals that the SCRAM exchange cannot continue.
   *
   * @param message description of what went wrong
   */
  private class ScramException(
      message: String
  ) : RuntimeException(message)
  /**
   * Generates a random 32-character client nonce.
   *
   * Characters are drawn from printable ASCII `'!'..'~'`, leaving out `,` and `=`. The space
   * character is deliberately not part of the pool: RFC 5802 defines the printable characters
   * allowed in a nonce as %x21-2B / %x2D-7E, which does not include 0x20.
   *
   * @return the newly generated nonce
   */
  private fun newNonce(): String {
      val charPool: List<Char> = ('!'..'~') - ',' - '='
      val random = SecureRandom().asKotlinRandom()
      return (1..32).joinToString("") { charPool.random(random).toString() }
  }
  /**
   * Mutable state carried between the steps of a SCRAM exchange.
   *
   * @property scramStage the step to perform when the next message arrives
   * @property clientNonce the nonce generated by the client for this exchange
   * @property serverNonce the nonce received from the server
   * @property iterCount the iteration count received from the server
   * @property salt the decoded salt received from the server
   * @property saltedPassword the key derived from the password, salt and iteration count
   * @property authMessage the message that client and server signatures are computed over
   */
  internal class ScramState(
      var scramStage: ScramStage = ScramStage.SendingFirstMessage,
      val clientNonce: String = newNonce(),
      var serverNonce: String = "",
      var iterCount: Int = 1,
      var salt: ByteArray = ByteArray(0),
      var saltedPassword: ByteArray = ByteArray(0),
      var authMessage: ByteArray = ByteArray(0)
  )
  /** The steps of a SCRAM exchange, in the order they are performed. */
  internal enum class ScramStage {

      /** The client's first message has not been sent yet. */
      SendingFirstMessage,

      /** The first message has been sent; the server's challenge is answered next. */
      SendingSecondMessage,

      /** The client proof has been sent; the server's signature is checked next. */
      Finishing

  }
  /**
   * Attribute types that can appear in a SCRAM message.
   *
   * @property prefix the single character that identifies the attribute on the wire
   */
  internal enum class ScramMessageType(val prefix: Char) {

      /** Authentication name. */
      AuthName('n'),

      /** Future extensions. */
      FutureExtensions('m'),

      /** Nonce. */
      Nonce('r'),

      /** Channel binding. */
      ChannelBinding('c'),

      /** Salt. */
      Salt('s'),

      /** Iteration count. */
      IterationCount('i'),

      /** Client proof. */
      ClientProof('p'),

      /** Server verifier. */
      ServerVerifier('v'),

      /** Error. */
      Error('e'),

  }
  /**
   * Looks up the [ScramMessageType] whose prefix is [prefix].
   *
   * @return the matching type, or [ScramMessageType.Error] if no type uses that prefix
   */
  private fun getMessageType(prefix: Char): ScramMessageType {
      for (candidate in ScramMessageType.values()) {
          if (candidate.prefix == prefix) {
              return candidate
          }
      }
      return ScramMessageType.Error
  }