You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

sasl.go 4.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. package oauth2
  2. /*
  3. https://github.com/emersion/go-sasl/blob/e73c9f7bad438a9bf3f5b28e661b74d752ecafdd/oauthbearer.go
  4. Copyright 2019-2022 Simon Ser, Frode Aannevik, Max Mazurov
  5. Released under the MIT license
  6. */
  7. import (
  8. "bytes"
  9. "encoding/json"
  10. "errors"
  11. "fmt"
  12. "strconv"
  13. "strings"
  14. )
  15. var (
  16. ErrUnexpectedClientResponse = errors.New("unexpected client response")
  17. )
  18. // The OAUTHBEARER mechanism name.
  19. const OAuthBearer = "OAUTHBEARER"
  20. type OAuthBearerError struct {
  21. Status string `json:"status"`
  22. Schemes string `json:"schemes"`
  23. Scope string `json:"scope"`
  24. }
  25. type OAuthBearerOptions struct {
  26. Username string `json:"username,omitempty"`
  27. Token string `json:"token,omitempty"`
  28. Host string `json:"host,omitempty"`
  29. Port int `json:"port,omitempty"`
  30. }
  31. func (err *OAuthBearerError) Error() string {
  32. return fmt.Sprintf("OAUTHBEARER authentication error (%v)", err.Status)
  33. }
  34. type OAuthBearerAuthenticator func(opts OAuthBearerOptions) *OAuthBearerError
  35. type OAuthBearerServer struct {
  36. done bool
  37. failErr error
  38. authenticate OAuthBearerAuthenticator
  39. }
  40. func (a *OAuthBearerServer) fail(descr string) ([]byte, bool, error) {
  41. blob, err := json.Marshal(OAuthBearerError{
  42. Status: "invalid_request",
  43. Schemes: "bearer",
  44. })
  45. if err != nil {
  46. panic(err) // wtf
  47. }
  48. a.failErr = errors.New(descr)
  49. return blob, false, nil
  50. }
  51. func (a *OAuthBearerServer) Next(response []byte) (challenge []byte, done bool, err error) {
  52. // Per RFC, we cannot just send an error, we need to return JSON-structured
  53. // value as a challenge and then after getting dummy response from the
  54. // client stop the exchange.
  55. if a.failErr != nil {
  56. // Server libraries (go-smtp, go-imap) will not call Next on
  57. // protocol-specific SASL cancel response ('*'). However, GS2 (and
  58. // indirectly OAUTHBEARER) defines a protocol-independent way to do so
  59. // using 0x01.
  60. if len(response) != 1 && response[0] != 0x01 {
  61. return nil, true, errors.New("unexpected response")
  62. }
  63. return nil, true, a.failErr
  64. }
  65. if a.done {
  66. err = ErrUnexpectedClientResponse
  67. return
  68. }
  69. // Generate empty challenge.
  70. if response == nil {
  71. return []byte{}, false, nil
  72. }
  73. a.done = true
  74. // Cut n,a=username,\x01host=...\x01auth=...
  75. // into
  76. // n
  77. // a=username
  78. // \x01host=...\x01auth=...\x01\x01
  79. parts := bytes.SplitN(response, []byte{','}, 3)
  80. if len(parts) != 3 {
  81. return a.fail("Invalid response")
  82. }
  83. flag := parts[0]
  84. authzid := parts[1]
  85. if !bytes.Equal(flag, []byte{'n'}) {
  86. return a.fail("Invalid response, missing 'n' in gs2-cb-flag")
  87. }
  88. opts := OAuthBearerOptions{}
  89. if len(authzid) > 0 {
  90. if !bytes.HasPrefix(authzid, []byte("a=")) {
  91. return a.fail("Invalid response, missing 'a=' in gs2-authzid")
  92. }
  93. opts.Username = string(bytes.TrimPrefix(authzid, []byte("a=")))
  94. }
  95. // Cut \x01host=...\x01auth=...\x01\x01
  96. // into
  97. // *empty*
  98. // host=...
  99. // auth=...
  100. // *empty*
  101. //
  102. // Note that this code does not do a lot of checks to make sure the input
  103. // follows the exact format specified by RFC.
  104. params := bytes.Split(parts[2], []byte{0x01})
  105. for _, p := range params {
  106. // Skip empty fields (one at start and end).
  107. if len(p) == 0 {
  108. continue
  109. }
  110. pParts := bytes.SplitN(p, []byte{'='}, 2)
  111. if len(pParts) != 2 {
  112. return a.fail("Invalid response, missing '='")
  113. }
  114. switch string(pParts[0]) {
  115. case "host":
  116. opts.Host = string(pParts[1])
  117. case "port":
  118. port, err := strconv.ParseUint(string(pParts[1]), 10, 16)
  119. if err != nil {
  120. return a.fail("Invalid response, malformed 'port' value")
  121. }
  122. opts.Port = int(port)
  123. case "auth":
  124. const prefix = "bearer "
  125. strValue := string(pParts[1])
  126. // Token type is case-insensitive.
  127. if !strings.HasPrefix(strings.ToLower(strValue), prefix) {
  128. return a.fail("Unsupported token type")
  129. }
  130. opts.Token = strValue[len(prefix):]
  131. default:
  132. return a.fail("Invalid response, unknown parameter: " + string(pParts[0]))
  133. }
  134. }
  135. authzErr := a.authenticate(opts)
  136. if authzErr != nil {
  137. blob, err := json.Marshal(authzErr)
  138. if err != nil {
  139. panic(err) // wtf
  140. }
  141. a.failErr = authzErr
  142. return blob, false, nil
  143. }
  144. return nil, true, nil
  145. }
  146. func NewOAuthBearerServer(auth OAuthBearerAuthenticator) *OAuthBearerServer {
  147. return &OAuthBearerServer{
  148. authenticate: auth,
  149. }
  150. }