// Copyright 2018 by David A. Golden. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 package scram import ( "encoding/base64" "errors" "fmt" "strconv" "strings" ) type c1Msg struct { gs2Header string authzID string username string nonce string c1b string } type c2Msg struct { cbind []byte nonce string proof []byte c2wop string } type s1Msg struct { nonce string salt []byte iters int } type s2Msg struct { verifier []byte err string } func parseField(s, k string) (string, error) { t := strings.TrimPrefix(s, k+"=") if t == s { return "", fmt.Errorf("error parsing '%s' for field '%s'", s, k) } return t, nil } func parseGS2Flag(s string) (string, error) { if s[0] == 'p' { return "", fmt.Errorf("channel binding requested but not supported") } if s == "n" || s == "y" { return s, nil } return "", fmt.Errorf("error parsing '%s' for gs2 flag", s) } func parseFieldBase64(s, k string) ([]byte, error) { raw, err := parseField(s, k) if err != nil { return nil, err } dec, err := base64.StdEncoding.DecodeString(raw) if err != nil { return nil, err } return dec, nil } func parseFieldInt(s, k string) (int, error) { raw, err := parseField(s, k) if err != nil { return 0, err } num, err := strconv.Atoi(raw) if err != nil { return 0, fmt.Errorf("error parsing field '%s': %v", k, err) } return num, nil } func parseClientFirst(c1 string) (msg c1Msg, err error) { fields := strings.Split(c1, ",") if len(fields) < 4 { err = errors.New("not enough fields in first server message") return } gs2flag, err := parseGS2Flag(fields[0]) if err != nil { return } // 'a' field is optional if len(fields[1]) > 0 { msg.authzID, err = parseField(fields[1], "a") if err != nil { return } } // Recombine and save the gs2 header msg.gs2Header = gs2flag + "," + msg.authzID + "," // Check for unsupported extensions field "m". if strings.HasPrefix(fields[2], "m=") { err = errors.New("SCRAM message extensions are not supported") return } msg.username, err = parseField(fields[2], "n") if err != nil { return } msg.nonce, err = parseField(fields[3], "r") if err != nil { return } msg.c1b = strings.Join(fields[2:], ",") return } func parseClientFinal(c2 string) (msg c2Msg, err error) { fields := strings.Split(c2, ",") if len(fields) < 3 { err = errors.New("not enough fields in first server message") return } msg.cbind, err = parseFieldBase64(fields[0], "c") if err != nil { return } msg.nonce, err = parseField(fields[1], "r") if err != nil { return } // Extension fields may come between nonce and proof, so we // grab the *last* fields as proof. msg.proof, err = parseFieldBase64(fields[len(fields)-1], "p") if err != nil { return } msg.c2wop = c2[:strings.LastIndex(c2, ",")] return } func parseServerFirst(s1 string) (msg s1Msg, err error) { // Check for unsupported extensions field "m". if strings.HasPrefix(s1, "m=") { err = errors.New("SCRAM message extensions are not supported") return } fields := strings.Split(s1, ",") if len(fields) < 3 { err = errors.New("not enough fields in first server message") return } msg.nonce, err = parseField(fields[0], "r") if err != nil { return } msg.salt, err = parseFieldBase64(fields[1], "s") if err != nil { return } msg.iters, err = parseFieldInt(fields[2], "i") return } func parseServerFinal(s2 string) (msg s2Msg, err error) { fields := strings.Split(s2, ",") msg.verifier, err = parseFieldBase64(fields[0], "v") if err == nil { return } msg.err, err = parseField(fields[0], "e") return }