You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

bearer.go 3.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. // Copyright (c) 2024 Shivaram Lingamneni <slingamn@cs.stanford.edu>
  2. // released under the MIT license
  3. package jwt
  4. import (
  5. "fmt"
  6. "io"
  7. "os"
  8. "strings"
  9. jwt "github.com/golang-jwt/jwt/v5"
  10. )
  11. var (
  12. ErrAuthDisabled = fmt.Errorf("JWT authentication is disabled")
  13. ErrNoValidAccountClaim = fmt.Errorf("JWT token did not contain an acceptable account name claim")
  14. )
  15. // JWTAuthConfig is the config for Ergo to accept JWTs via draft/bearer
  16. type JWTAuthConfig struct {
  17. Enabled bool `yaml:"enabled"`
  18. Autocreate bool `yaml:"autocreate"`
  19. Tokens []JWTAuthTokenConfig `yaml:"tokens"`
  20. }
  21. type JWTAuthTokenConfig struct {
  22. Algorithm string `yaml:"algorithm"`
  23. KeyString string `yaml:"key"`
  24. KeyFile string `yaml:"key-file"`
  25. key any
  26. parser *jwt.Parser
  27. AccountClaims []string `yaml:"account-claims"`
  28. StripDomain string `yaml:"strip-domain"`
  29. }
  30. func (j *JWTAuthConfig) Postprocess() error {
  31. if !j.Enabled {
  32. return nil
  33. }
  34. if len(j.Tokens) == 0 {
  35. return fmt.Errorf("JWT authentication enabled, but no valid tokens defined")
  36. }
  37. for i := range j.Tokens {
  38. if err := j.Tokens[i].Postprocess(); err != nil {
  39. return err
  40. }
  41. }
  42. return nil
  43. }
  44. func (j *JWTAuthTokenConfig) Postprocess() error {
  45. keyBytes, err := j.keyBytes()
  46. if err != nil {
  47. return err
  48. }
  49. j.Algorithm = strings.ToLower(j.Algorithm)
  50. var methods []string
  51. switch j.Algorithm {
  52. case "hmac":
  53. j.key = keyBytes
  54. methods = []string{"HS256", "HS384", "HS512"}
  55. case "rsa":
  56. rsaKey, err := jwt.ParseRSAPublicKeyFromPEM(keyBytes)
  57. if err != nil {
  58. return err
  59. }
  60. j.key = rsaKey
  61. methods = []string{"RS256", "RS384", "RS512"}
  62. case "eddsa":
  63. eddsaKey, err := jwt.ParseEdPublicKeyFromPEM(keyBytes)
  64. if err != nil {
  65. return err
  66. }
  67. j.key = eddsaKey
  68. methods = []string{"EdDSA"}
  69. default:
  70. return fmt.Errorf("invalid jwt algorithm: %s", j.Algorithm)
  71. }
  72. j.parser = jwt.NewParser(jwt.WithValidMethods(methods))
  73. if len(j.AccountClaims) == 0 {
  74. return fmt.Errorf("JWT auth enabled, but no account-claims specified")
  75. }
  76. j.StripDomain = strings.ToLower(j.StripDomain)
  77. return nil
  78. }
  79. func (j *JWTAuthConfig) Validate(t string) (accountName string, err error) {
  80. if !j.Enabled || len(j.Tokens) == 0 {
  81. return "", ErrAuthDisabled
  82. }
  83. for i := range j.Tokens {
  84. accountName, err = j.Tokens[i].Validate(t)
  85. if err == nil {
  86. return
  87. }
  88. }
  89. return
  90. }
  91. func (j *JWTAuthTokenConfig) keyBytes() (result []byte, err error) {
  92. if j.KeyFile != "" {
  93. o, err := os.Open(j.KeyFile)
  94. if err != nil {
  95. return nil, err
  96. }
  97. defer o.Close()
  98. return io.ReadAll(o)
  99. }
  100. if j.KeyString != "" {
  101. return []byte(j.KeyString), nil
  102. }
  103. return nil, fmt.Errorf("JWT auth enabled, but no JWT key specified")
  104. }
  105. // implements jwt.Keyfunc
  106. func (j *JWTAuthTokenConfig) keyFunc(_ *jwt.Token) (interface{}, error) {
  107. return j.key, nil
  108. }
  109. func (j *JWTAuthTokenConfig) Validate(t string) (accountName string, err error) {
  110. token, err := j.parser.Parse(t, j.keyFunc)
  111. if err != nil {
  112. return "", err
  113. }
  114. claims, ok := token.Claims.(jwt.MapClaims)
  115. if !ok {
  116. // impossible with Parse (as opposed to ParseWithClaims)
  117. return "", fmt.Errorf("unexpected type from parsed token claims: %T", claims)
  118. }
  119. for _, c := range j.AccountClaims {
  120. if v, ok := claims[c]; ok {
  121. if vstr, ok := v.(string); ok {
  122. // validate and strip email addresses:
  123. if idx := strings.IndexByte(vstr, '@'); idx != -1 {
  124. suffix := vstr[idx+1:]
  125. vstr = vstr[:idx]
  126. if strings.ToLower(suffix) != j.StripDomain {
  127. continue
  128. }
  129. }
  130. return vstr, nil // success
  131. }
  132. }
  133. }
  134. return "", ErrNoValidAccountClaim
  135. }