Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion EXAMPLES.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ assertjson.Has(t, data, func(json *assertjson.AssertJSON) {
```go
import (
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/muonsoft/api-testing/jwt"
)

assertjson.Has(t, data, func(json *assertjson.AssertJSON) {
Expand Down
2 changes: 1 addition & 1 deletion apitest/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"net/http/httptest"
"testing"

"github.com/golang-jwt/jwt/v5"
"github.com/muonsoft/api-testing/jwt"

Check failure on line 8 in apitest/response_test.go

View workflow job for this annotation

GitHub Actions / test

File is not properly formatted (gci)
"github.com/muonsoft/api-testing/apitest"
"github.com/muonsoft/api-testing/assertjson"
"github.com/muonsoft/api-testing/internal/mock"
Expand Down
10 changes: 5 additions & 5 deletions assertions/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"testing"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/muonsoft/api-testing/jwt"

Check failure on line 10 in assertions/jwt.go

View workflow job for this annotation

GitHub Actions / test

File is not properly formatted (gci)
"github.com/muonsoft/api-testing/assertjson"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -84,7 +84,7 @@
jsonAssert(assertjson.NewAssertJSON(
a.t,
a.messagePrefix+`is JWT with payload: `,
map[string]interface{}(a.token.Claims.(jwt.MapClaims)),
map[string]interface{}(a.token.Claims),
))

return a
Expand Down Expand Up @@ -188,7 +188,7 @@
func (a *JWTAssertion) assertStringField(title string, name string, expected string, msgAndArgs ...interface{}) *JWTAssertion {
a.t.Helper()

raw, exist := a.token.Claims.(jwt.MapClaims)[name]
raw, exist := a.token.Claims[name]
if !exist {
return a.failOnMissingField(title, name, strconv.Quote(expected), msgAndArgs...)
}
Expand All @@ -208,7 +208,7 @@
func (a *JWTAssertion) assertStringsField(title string, name string, expected []string, msgAndArgs ...interface{}) *JWTAssertion {
a.t.Helper()

raw, exist := a.token.Claims.(jwt.MapClaims)[name]
raw, exist := a.token.Claims[name]
if !exist {
return a.failOnMissingField(title, name, wrapArray(formatStrings(expected)), msgAndArgs...)
}
Expand All @@ -226,7 +226,7 @@
}

func (a *JWTAssertion) assertTimeField(title string, name string) *TimeAssertion {
raw, exist := a.token.Claims.(jwt.MapClaims)[name]
raw, exist := a.token.Claims[name]
if !exist {
a.failOnMissingField(title, name, "")
return nil
Expand Down
6 changes: 3 additions & 3 deletions assertjson/assertjson_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"time"

"github.com/gofrs/uuid/v5"
"github.com/golang-jwt/jwt/v5"
"github.com/muonsoft/api-testing/jwt"
"github.com/muonsoft/api-testing/assertjson"
"github.com/muonsoft/api-testing/internal/mock"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -2783,7 +2783,7 @@ func TestHas(t *testing.T) {
json.Node().IsJWT(getJWTSecret).WithExpiresAt()
},
wantMessages: []string{
`failed asserting that JSON node "" is JWT: token has invalid claims: invalid type for claim: exp is invalid`,
`is JWT with expires at ("exp") : number is expected`,
},
},
{
Expand Down Expand Up @@ -2820,7 +2820,7 @@ func TestHas(t *testing.T) {
json.Node().IsJWT(getJWTSecret).WithNotBefore()
},
wantMessages: []string{
`failed asserting that JSON node "" is JWT: token has invalid claims: invalid type for claim: nbf is invalid`,
`is JWT with not before ("nbf") : number is expected`,
},
},
{
Expand Down
10 changes: 5 additions & 5 deletions assertjson/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"testing"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/muonsoft/api-testing/jwt"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -107,7 +107,7 @@ func (a *JWTAssertion) WithPayload(jsonAssert JSONAssertFunc) *JWTAssertion {
jsonAssert(&AssertJSON{
t: a.t,
message: a.message + `is JWT with payload: `,
data: map[string]interface{}(a.token.Claims.(jwt.MapClaims)),
data: map[string]interface{}(a.token.Claims),
})

return a
Expand Down Expand Up @@ -217,7 +217,7 @@ func (a *JWTAssertion) Assert(assertFunc func(tb testing.TB, token *jwt.Token))
func (a *JWTAssertion) assertStringField(title string, name string, expected string, msgAndArgs ...interface{}) *JWTAssertion {
a.t.Helper()

raw, exist := a.token.Claims.(jwt.MapClaims)[name]
raw, exist := a.token.Claims[name]
if !exist {
return a.failOnMissingField(title, name, strconv.Quote(expected), msgAndArgs...)
}
Expand All @@ -237,7 +237,7 @@ func (a *JWTAssertion) assertStringField(title string, name string, expected str
func (a *JWTAssertion) assertStringsField(title string, name string, expected []string, msgAndArgs ...interface{}) *JWTAssertion {
a.t.Helper()

raw, exist := a.token.Claims.(jwt.MapClaims)[name]
raw, exist := a.token.Claims[name]
if !exist {
return a.failOnMissingField(title, name, wrapArray(formatStrings(expected)), msgAndArgs...)
}
Expand All @@ -255,7 +255,7 @@ func (a *JWTAssertion) assertStringsField(title string, name string, expected []
}

func (a *JWTAssertion) assertTimeField(title string, name string) *TimeAssertion {
raw, exist := a.token.Claims.(jwt.MapClaims)[name]
raw, exist := a.token.Claims[name]
if !exist {
a.failOnMissingField(title, name, "")
return nil
Expand Down
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go 1.16

require (
github.com/gofrs/uuid/v5 v5.3.2
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/json-iterator/go v1.1.12
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/stretchr/testify v1.10.0
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gofrs/uuid/v5 v5.3.2 h1:2jfO8j3XgSwlz/wHqemAEugfnTlikAYHhnqQ8Xh4fE0=
github.com/gofrs/uuid/v5 v5.3.2/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
Expand Down
11 changes: 11 additions & 0 deletions jwt/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package jwt

import "errors"

var (
ErrTokenMalformed = errors.New("token is malformed")
ErrTokenUnverifiable = errors.New("token is unverifiable")
ErrTokenSignatureInvalid = errors.New("token signature is invalid")
ErrSignatureInvalid = errors.New("signature is invalid")
ErrInvalidKeyType = errors.New("key is of invalid type")
)
43 changes: 43 additions & 0 deletions jwt/hmac.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package jwt

import (
"crypto/hmac"
"crypto/sha256"
)

// SigningMethodHMAC implements HS256.
type SigningMethodHMAC struct {
Name string
}

var signingMethodHS256 = &SigningMethodHMAC{Name: "HS256"}

// SigningMethodHS256 is the HMAC-SHA256 signing method.
var SigningMethodHS256 SigningMethod = signingMethodHS256

func (m *SigningMethodHMAC) Alg() string {
return m.Name
}

func (m *SigningMethodHMAC) Verify(signingString string, sig []byte, key interface{}) error {
keyBytes, ok := key.([]byte)
if !ok {
return ErrInvalidKeyType
}
hasher := hmac.New(sha256.New, keyBytes)
hasher.Write([]byte(signingString))
if !hmac.Equal(sig, hasher.Sum(nil)) {
return ErrSignatureInvalid
}
return nil
}

func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) ([]byte, error) {
keyBytes, ok := key.([]byte)
if !ok {
return nil, ErrInvalidKeyType
}
hasher := hmac.New(sha256.New, keyBytes)
hasher.Write([]byte(signingString))
return hasher.Sum(nil), nil
}
5 changes: 5 additions & 0 deletions jwt/map_claims.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package jwt

// MapClaims is a claims type that uses map[string]interface{} for JSON decoding.
// Used as the default claims type for parsing and creating tokens.
type MapClaims map[string]interface{}
84 changes: 84 additions & 0 deletions jwt/parse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package jwt

import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
)

const tokenDelimiter = "."

// Parse parses and verifies the JWT and returns the token.
// Only HS256 signature verification is supported.
func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {

Check failure on line 14 in jwt/parse.go

View workflow job for this annotation

GitHub Actions / test

calculated cyclomatic complexity for function Parse is 11, max is 10 (cyclop)
parts, ok := splitToken(tokenString)
if !ok {
return nil, fmt.Errorf("%w: token contains an invalid number of segments", ErrTokenMalformed)
}

token := &Token{Raw: tokenString}

// Decode header
headerBytes, err := decodeSegment(parts[0])
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrTokenMalformed, err)

Check failure on line 25 in jwt/parse.go

View workflow job for this annotation

GitHub Actions / test

non-wrapping format verb for fmt.Errorf. Use `%w` to format errors (errorlint)
}
if err := json.Unmarshal(headerBytes, &token.Header); err != nil {
return nil, fmt.Errorf("%w: %v", ErrTokenMalformed, err)

Check failure on line 28 in jwt/parse.go

View workflow job for this annotation

GitHub Actions / test

non-wrapping format verb for fmt.Errorf. Use `%w` to format errors (errorlint)
}

// Decode claims
claimBytes, err := decodeSegment(parts[1])
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrTokenMalformed, err)

Check failure on line 34 in jwt/parse.go

View workflow job for this annotation

GitHub Actions / test

non-wrapping format verb for fmt.Errorf. Use `%w` to format errors (errorlint)
}
token.Claims = MapClaims{}
if err := json.Unmarshal(claimBytes, &token.Claims); err != nil {
return nil, fmt.Errorf("%w: %v", ErrTokenMalformed, err)
}

// Resolve signing method from header
alg, _ := token.Header["alg"].(string)
if alg == "" {
return nil, fmt.Errorf("%w: signing method (alg) is unspecified", ErrTokenUnverifiable)
}
token.Method = &methodByAlg{alg: alg}

// Decode signature
token.Signature, err = decodeSegment(parts[2])
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrTokenMalformed, err)
}

if keyFunc == nil {
return nil, fmt.Errorf("%w: no keyfunc was provided", ErrTokenUnverifiable)
}
key, err := keyFunc(token)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrTokenUnverifiable, err)
}

signingString := strings.Join(parts[0:2], ".")
if err := token.Method.Verify(signingString, token.Signature, key); err != nil {
return nil, fmt.Errorf("%w: %v", ErrTokenSignatureInvalid, err)
}

token.Valid = true
return token, nil
}

func splitToken(s string) ([]string, bool) {
parts := strings.SplitN(s, tokenDelimiter, 4)
if len(parts) != 3 {
return nil, false
}
if parts[0] == "" || parts[1] == "" || parts[2] == "" {
return nil, false
}
return parts, true
}

func decodeSegment(seg string) ([]byte, error) {
return base64.RawURLEncoding.DecodeString(seg)
}
28 changes: 28 additions & 0 deletions jwt/signing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package jwt

// SigningMethod is used to sign and verify tokens.
type SigningMethod interface {
Verify(signingString string, sig []byte, key interface{}) error
Sign(signingString string, key interface{}) ([]byte, error)
Alg() string
}

// methodByAlg holds algorithm name from token header; only HS256 is verified.
type methodByAlg struct {
alg string
}

func (m *methodByAlg) Alg() string {
return m.alg
}

func (m *methodByAlg) Verify(signingString string, sig []byte, key interface{}) error {
if m.alg != "HS256" {
return ErrTokenSignatureInvalid
}
return signingMethodHS256.Verify(signingString, sig, key)
}

func (m *methodByAlg) Sign(signingString string, key interface{}) ([]byte, error) {
return nil, ErrTokenUnverifiable
}
67 changes: 67 additions & 0 deletions jwt/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package jwt

import (
"encoding/base64"
"encoding/json"
)

// Keyfunc is used by Parse to supply the key for verification.
// The function receives the parsed but unverified Token (e.g. to read "alg" from header).
type Keyfunc func(*Token) (interface{}, error)

// Token represents a JWT.
type Token struct {
Raw string

Check failure on line 14 in jwt/token.go

View workflow job for this annotation

GitHub Actions / test

File is not properly formatted (gci)
Method SigningMethod
Header map[string]interface{}
Claims MapClaims
Signature []byte
Valid bool
}

// NewWithClaims creates a new Token with the given signing method and claims.
func NewWithClaims(method SigningMethod, claims MapClaims) *Token {
if claims == nil {
claims = MapClaims{}
}
return &Token{
Header: map[string]interface{}{
"typ": "JWT",
"alg": method.Alg(),
},
Claims: claims,
Method: method,
}
}

// SignedString signs the token and returns the full JWT string.
func (t *Token) SignedString(key interface{}) (string, error) {
sstr, err := t.SigningString()
if err != nil {
return "", err
}
sig, err := t.Method.Sign(sstr, key)
if err != nil {
return "", err
}
t.Signature = sig
return sstr + "." + t.EncodeSegment(sig), nil
}

// SigningString returns the base64url(header).base64url(claims) string.
func (t *Token) SigningString() (string, error) {
h, err := json.Marshal(t.Header)
if err != nil {
return "", err
}
c, err := json.Marshal(t.Claims)
if err != nil {
return "", err
}
return t.EncodeSegment(h) + "." + t.EncodeSegment(c), nil
}

// EncodeSegment encodes bytes to base64url without padding.
func (t *Token) EncodeSegment(seg []byte) string {
return base64.RawURLEncoding.EncodeToString(seg)
}
Loading