123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
- // Copyright 2018 by David A. Golden. All rights reserved.
- //
- // Licensed under the Apache License, Version 2.0 (the "License"); you may
- // not use this file except in compliance with the License. You may obtain
- // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
-
- package scram
-
- import (
- "encoding/base64"
- "errors"
- "fmt"
- "strconv"
- "strings"
- )
-
// c1Msg holds the pieces extracted from a client-first message by
// parseClientFirst.
type c1Msg struct {
	// gs2Header is the GS2 header recombined from the flag and authzid fields.
	gs2Header string
	// authzID is the value of the optional "a" field; empty when absent.
	authzID string
	// username is the value of the "n" field.
	username string
	// nonce is the value of the "r" field.
	nonce string
	// c1b is everything after the GS2 header, re-joined with commas.
	c1b string
}
-
// c2Msg holds the pieces extracted from a client-final message by
// parseClientFinal.
type c2Msg struct {
	// cbind is the base64-decoded value of the "c" field.
	cbind []byte
	// nonce is the value of the "r" field.
	nonce string
	// proof is the base64-decoded value of the trailing "p" field.
	proof []byte
	// c2wop is the message text up to, but not including, the last comma
	// (i.e. the message without its final field).
	c2wop string
}
-
// s1Msg holds the pieces extracted from a server-first message by
// parseServerFirst.
type s1Msg struct {
	// nonce is the value of the "r" field.
	nonce string
	// salt is the base64-decoded value of the "s" field.
	salt []byte
	// iters is the integer value of the "i" field.
	iters int
}
-
// s2Msg holds the outcome extracted from a server-final message by
// parseServerFinal: either a verifier or an error value.
type s2Msg struct {
	// verifier is the base64-decoded value of the "v" field.
	verifier []byte
	// err is the value of the "e" field.
	err string
}
-
// parseField returns the value part of a "k=value" attribute. It fails
// when s does not begin with k followed by "=".
func parseField(s, k string) (string, error) {
	prefix := k + "="
	if !strings.HasPrefix(s, prefix) {
		return "", fmt.Errorf("error parsing '%s' for field '%s'", s, k)
	}
	return s[len(prefix):], nil
}
-
// parseGS2Flag validates the channel-binding flag that opens a GS2 header.
//
// It returns s unchanged when s is exactly "n" or "y". Any value starting
// with "p" (a channel binding request) is rejected as unsupported, and
// every other value, including the empty string, is rejected as malformed.
func parseGS2Flag(s string) (string, error) {
	// HasPrefix is used rather than indexing s[0] so that an empty flag
	// (e.g. from a message that starts with a comma) yields an error
	// instead of an index-out-of-range panic.
	if strings.HasPrefix(s, "p") {
		return "", errors.New("channel binding requested but not supported")
	}

	switch s {
	case "n", "y":
		return s, nil
	}

	return "", fmt.Errorf("error parsing '%s' for gs2 flag", s)
}
-
- func parseFieldBase64(s, k string) ([]byte, error) {
- raw, err := parseField(s, k)
- if err != nil {
- return nil, err
- }
-
- dec, err := base64.StdEncoding.DecodeString(raw)
- if err != nil {
- return nil, err
- }
-
- return dec, nil
- }
-
- func parseFieldInt(s, k string) (int, error) {
- raw, err := parseField(s, k)
- if err != nil {
- return 0, err
- }
-
- num, err := strconv.Atoi(raw)
- if err != nil {
- return 0, fmt.Errorf("error parsing field '%s': %v", k, err)
- }
-
- return num, nil
- }
-
- func parseClientFirst(c1 string) (msg c1Msg, err error) {
-
- fields := strings.Split(c1, ",")
- if len(fields) < 4 {
- err = errors.New("not enough fields in first server message")
- return
- }
-
- gs2flag, err := parseGS2Flag(fields[0])
- if err != nil {
- return
- }
-
- // 'a' field is optional
- if len(fields[1]) > 0 {
- msg.authzID, err = parseField(fields[1], "a")
- if err != nil {
- return
- }
- }
-
- // Recombine and save the gs2 header
- msg.gs2Header = gs2flag + "," + msg.authzID + ","
-
- // Check for unsupported extensions field "m".
- if strings.HasPrefix(fields[2], "m=") {
- err = errors.New("SCRAM message extensions are not supported")
- return
- }
-
- msg.username, err = parseField(fields[2], "n")
- if err != nil {
- return
- }
-
- msg.nonce, err = parseField(fields[3], "r")
- if err != nil {
- return
- }
-
- msg.c1b = strings.Join(fields[2:], ",")
-
- return
- }
-
- func parseClientFinal(c2 string) (msg c2Msg, err error) {
- fields := strings.Split(c2, ",")
- if len(fields) < 3 {
- err = errors.New("not enough fields in first server message")
- return
- }
-
- msg.cbind, err = parseFieldBase64(fields[0], "c")
- if err != nil {
- return
- }
-
- msg.nonce, err = parseField(fields[1], "r")
- if err != nil {
- return
- }
-
- // Extension fields may come between nonce and proof, so we
- // grab the *last* fields as proof.
- msg.proof, err = parseFieldBase64(fields[len(fields)-1], "p")
- if err != nil {
- return
- }
-
- msg.c2wop = c2[:strings.LastIndex(c2, ",")]
-
- return
- }
-
- func parseServerFirst(s1 string) (msg s1Msg, err error) {
-
- // Check for unsupported extensions field "m".
- if strings.HasPrefix(s1, "m=") {
- err = errors.New("SCRAM message extensions are not supported")
- return
- }
-
- fields := strings.Split(s1, ",")
- if len(fields) < 3 {
- err = errors.New("not enough fields in first server message")
- return
- }
-
- msg.nonce, err = parseField(fields[0], "r")
- if err != nil {
- return
- }
-
- msg.salt, err = parseFieldBase64(fields[1], "s")
- if err != nil {
- return
- }
-
- msg.iters, err = parseFieldInt(fields[2], "i")
-
- return
- }
-
- func parseServerFinal(s2 string) (msg s2Msg, err error) {
- fields := strings.Split(s2, ",")
-
- msg.verifier, err = parseFieldBase64(fields[0], "v")
- if err == nil {
- return
- }
-
- msg.err, err = parseField(fields[0], "e")
-
- return
- }
|