You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

bearer.go 3.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. // Copyright (c) 2024 Shivaram Lingamneni <slingamn@cs.stanford.edu>
  2. // released under the MIT license
  3. package jwt
  4. import (
  5. "fmt"
  6. "io"
  7. "os"
  8. "strings"
  9. jwt "github.com/golang-jwt/jwt/v5"
  10. )
  11. var (
  12. ErrAuthDisabled = fmt.Errorf("JWT authentication is disabled")
  13. ErrNoValidAccountClaim = fmt.Errorf("JWT token did not contain an acceptable account name claim")
  14. )
  15. // JWTAuthConfig is the config for Ergo to accept JWTs via draft/bearer
  16. type JWTAuthConfig struct {
  17. Enabled bool `yaml:"enabled"`
  18. Autocreate bool `yaml:"autocreate"`
  19. Tokens []JWTAuthTokenConfig `yaml:"tokens"`
  20. }
  21. type JWTAuthTokenConfig struct {
  22. Algorithm string `yaml:"algorithm"`
  23. KeyString string `yaml:"key"`
  24. KeyFile string `yaml:"key-file"`
  25. key any
  26. parser *jwt.Parser
  27. AccountClaims []string `yaml:"account-claims"`
  28. StripDomain string `yaml:"strip-domain"`
  29. }
  30. func (j *JWTAuthConfig) Postprocess() error {
  31. if !j.Enabled {
  32. return nil
  33. }
  34. if len(j.Tokens) == 0 {
  35. return fmt.Errorf("JWT authentication enabled, but no valid tokens defined")
  36. }
  37. for i := range j.Tokens {
  38. if err := j.Tokens[i].Postprocess(); err != nil {
  39. return err
  40. }
  41. }
  42. return nil
  43. }
  44. func (j *JWTAuthTokenConfig) Postprocess() error {
  45. keyBytes, err := j.keyBytes()
  46. if err != nil {
  47. return err
  48. }
  49. j.Algorithm = strings.ToLower(j.Algorithm)
  50. var methods []string
  51. switch j.Algorithm {
  52. case "hmac":
  53. j.key = keyBytes
  54. methods = []string{"HS256", "HS384", "HS512"}
  55. case "rsa":
  56. rsaKey, err := jwt.ParseRSAPublicKeyFromPEM(keyBytes)
  57. if err != nil {
  58. return err
  59. }
  60. j.key = rsaKey
  61. methods = []string{"RS256", "RS384", "RS512"}
  62. case "eddsa":
  63. eddsaKey, err := jwt.ParseEdPublicKeyFromPEM(keyBytes)
  64. if err != nil {
  65. return err
  66. }
  67. j.key = eddsaKey
  68. methods = []string{"EdDSA"}
  69. default:
  70. return fmt.Errorf("invalid jwt algorithm: %s", j.Algorithm)
  71. }
  72. j.parser = jwt.NewParser(jwt.WithValidMethods(methods))
  73. if len(j.AccountClaims) == 0 {
  74. return fmt.Errorf("JWT auth enabled, but no account-claims specified")
  75. }
  76. j.StripDomain = strings.ToLower(j.StripDomain)
  77. return nil
  78. }
  79. func (j *JWTAuthConfig) Validate(t string) (accountName string, err error) {
  80. if !j.Enabled || len(j.Tokens) == 0 {
  81. return "", ErrAuthDisabled
  82. }
  83. for i := range j.Tokens {
  84. accountName, err = j.Tokens[i].Validate(t)
  85. if err == nil {
  86. return
  87. }
  88. }
  89. return
  90. }
  91. func (j *JWTAuthTokenConfig) keyBytes() (result []byte, err error) {
  92. if j.KeyFile != "" {
  93. o, err := os.Open(j.KeyFile)
  94. if err != nil {
  95. return nil, err
  96. }
  97. return io.ReadAll(o)
  98. }
  99. if j.KeyString != "" {
  100. return []byte(j.KeyString), nil
  101. }
  102. return nil, fmt.Errorf("JWT auth enabled, but no JWT key specified")
  103. }
  104. // implements jwt.Keyfunc
  105. func (j *JWTAuthTokenConfig) keyFunc(_ *jwt.Token) (interface{}, error) {
  106. return j.key, nil
  107. }
  108. func (j *JWTAuthTokenConfig) Validate(t string) (accountName string, err error) {
  109. token, err := j.parser.Parse(t, j.keyFunc)
  110. if err != nil {
  111. return "", err
  112. }
  113. claims, ok := token.Claims.(jwt.MapClaims)
  114. if !ok {
  115. // impossible with Parse (as opposed to ParseWithClaims)
  116. return "", fmt.Errorf("unexpected type from parsed token claims: %T", claims)
  117. }
  118. for _, c := range j.AccountClaims {
  119. if v, ok := claims[c]; ok {
  120. if vstr, ok := v.(string); ok {
  121. // validate and strip email addresses:
  122. if idx := strings.IndexByte(vstr, '@'); idx != -1 {
  123. suffix := vstr[idx+1:]
  124. vstr = vstr[:idx]
  125. if strings.ToLower(suffix) != j.StripDomain {
  126. continue
  127. }
  128. }
  129. return vstr, nil // success
  130. }
  131. }
  132. }
  133. return "", ErrNoValidAccountClaim
  134. }