Created
September 14, 2023 13:43
-
-
Save andreyvit/d6ebd84cf1086efddffe2001f9eb4719 to your computer and use it in GitHub Desktop.
Example of a simple JWT implementation in Go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package jwt | |
import ( | |
"bytes" | |
"crypto/hmac" | |
"crypto/sha256" | |
"crypto/subtle" | |
"encoding/base64" | |
"encoding/json" | |
"errors" | |
"time" | |
"unsafe" | |
) | |
// Sentinel errors returned by Parse, ValidateHS256 and ValidateTimeAt.
var (
	ErrCorrupted          = errors.New("token corrupted")
	ErrAlg                = errors.New("token uses a wrong algorithm")
	ErrExpired            = errors.New("token expired")
	ErrNotYetValid        = errors.New("token not valid yet")
	ErrTooLong            = errors.New("token too long")
	ErrSignature          = errors.New("token signature invalid")
	ErrSignatureCorrupted = errors.New("token signature corrupted")
)

// Tunable limits; these are variables so that callers can adjust them.
var (
	// MaxTokenLen is the safety limit to avoid decoding very long data.
	MaxTokenLen = 8000
	// ExpectedClaimCount is a starting size for the claims map.
	ExpectedClaimCount = 10
)
// Algorithm names a JWS signing algorithm, as carried in the "alg" header.
type Algorithm string

// Registered claim names.
const (
	TokenID     = "jti" // TokenID is a unique identifier for this token.
	Issuer      = "iss" // Issuer is the principal that issued the token.
	Audience    = "aud" // Audience identifies the recipients the token is intended for.
	Subject     = "sub" // Subject is the user/account/etc. that this token authorizes access to.
	IssuedAt    = "iat" // IssuedAt is a Unix timestamp for when the token was issued.
	ExpiresAt   = "exp" // ExpiresAt is a Unix timestamp for when the token should expire.
	NotBeforeAt = "nbf" // NotBeforeAt is a timestamp this token should not be accepted until.
)

// Header field names.
const (
	Alg   = "alg" // Alg is a header field identifying the signing algorithm.
	Typ   = "typ" // Typ is a header field that must be set to "JWT".
	KeyID = "kid" // KeyID is a header field, an opaque string identifying the key used.
)

// Forever is the validity duration of tokens that do not expire.
const Forever time.Duration = 1<<63 - 1

// HS256 (HMAC-SHA256) parameters.
const (
	HS256 Algorithm = "HS256"

	MinHS256KeyLen = 32
	MaxHS256KeyLen = 64 // anything longer is hashed to 32 bytes

	hs256SignatureEncLen = 43                                     // RawURLEncoding.EncodedLen(sha256.Size)
	hs256Header          = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" // {"alg":"HS256","typ":"JWT"}
)

const (
	// jwtTyp is the only "typ" header value Parse accepts.
	jwtTyp = "JWT"
	// stackClaimsSpace is the size of the fixed local buffer Parse uses for
	// small claim payloads, intended to avoid a heap allocation.
	stackClaimsSpace = 512
)
// Claims is a JWT claim set keyed by claim name (see the TokenID, Issuer, ...
// constants). Claims decoded by Parse hold numbers as json.Number, while those
// built by NewAt hold int64 timestamps; Int64 and Time read either form.
type Claims map[string]any | |
func New(subject string, validity time.Duration) Claims { | |
return NewAt(subject, validity, time.Now()) | |
} | |
func NewAt(subject string, validity time.Duration, now time.Time) Claims { | |
if validity == 0 { | |
// accepting 0 would allow a misconfiguration to escalate into a security issue | |
panic("zero validity is invalid, use Forever for non-expiring tokens") | |
} | |
c := make(Claims, ExpectedClaimCount) | |
c[IssuedAt] = now.Unix() | |
if validity != Forever { | |
c[ExpiresAt] = now.Add(validity).Unix() | |
} | |
if subject != "" { | |
c[Subject] = subject | |
} | |
return c | |
} | |
func (c Claims) String(key string) string { | |
if v, ok := c[key].(string); ok { | |
return v | |
} else { | |
return "" | |
} | |
} | |
func (c Claims) Int64(key string) (int64, bool) { | |
switch v := c[key].(type) { | |
case json.Number: | |
n, err := v.Int64() | |
if err != nil { | |
return 0, false | |
} | |
return n, true | |
case int: | |
return int64(v), true | |
case int64: | |
return v, true | |
case float64: | |
return int64(v), true | |
default: | |
return 0, false | |
} | |
} | |
func (c Claims) Time(key string) time.Time { | |
if v, ok := c.Int64(key); ok && v != 0 { | |
return time.Unix(v, 0) | |
} else { | |
return time.Time{} | |
} | |
} | |
func (c Claims) ExpiresAt() time.Time { | |
return c.Time(ExpiresAt) | |
} | |
func (c Claims) Issuer() string { | |
return c.String(Issuer) | |
} | |
func (c Claims) Subject() string { | |
return c.String(Subject) | |
} | |
func (c Claims) TokenID() string { | |
return c.String(TokenID) | |
} | |
func (c Claims) ValidateTime(tolerance time.Duration) error { | |
return c.ValidateTimeAt(tolerance, time.Now()) | |
} | |
func (c Claims) ValidateTimeAt(tolerance time.Duration, now time.Time) error { | |
if exp := c.ExpiresAt(); !exp.IsZero() { | |
if now.After(exp.Add(tolerance)) { | |
return ErrExpired | |
} | |
} | |
if exp := c.Time(NotBeforeAt); !exp.IsZero() { | |
if now.Before(exp.Add(-tolerance)) { | |
return ErrNotYetValid | |
} | |
} | |
return nil | |
} | |
// header mirrors the JOSE header fields that Parse reads when a token does not
// use the canonical HS256 header. In this file it is only used for decoding.
type header struct { | |
// Alg is the signing algorithm named by the token.
Alg string `json:"alg"` | |
// Typ must equal "JWT" for Parse to accept the token.
Typ string `json:"typ"` | |
// KeyID is the optional opaque key identifier.
KeyID string `json:"kid"` | |
} | |
func SignHS256String(claims, headerClaims Claims, key []byte) string { | |
b := SignHS256(claims, headerClaims, key, nil) | |
return unsafe.String(&b[0], len(b)) | |
} | |
// SignHS256 produces a signed JWT token from the given claims. | |
func SignHS256(claims, headerClaims Claims, key []byte, buf []byte) []byte { | |
if len(key) == 0 { | |
panic("missing key") | |
} | |
rawClaims, err := json.Marshal(claims) | |
if err != nil { | |
panic(err) | |
} | |
var rawHeader []byte | |
if headerClaims != nil { | |
headerClaims[Typ] = jwtTyp | |
headerClaims[Alg] = string(HS256) | |
rawHeader, err = json.Marshal(headerClaims) | |
if err != nil { | |
panic(err) | |
} | |
rawHeader = []byte(base64.RawURLEncoding.EncodeToString(rawHeader)) | |
} | |
return SignHS256Raw(rawClaims, rawHeader, key, buf) | |
} | |
// SignHS256Raw produces a signed JWT token from the given raw claims. | |
func SignHS256Raw(claims, header []byte, key []byte, buf []byte) []byte { | |
headerLen := len(header) | |
if headerLen == 0 { | |
headerLen = len(hs256Header) | |
} | |
claimLen := base64.RawURLEncoding.EncodedLen(len(claims)) | |
tokenLen := headerLen + 1 + claimLen + 1 + hs256SignatureEncLen | |
if len(buf) < tokenLen { | |
buf = make([]byte, tokenLen) | |
} | |
if len(header) == 0 { | |
copy(buf, hs256Header) | |
} else { | |
copy(buf, header) | |
} | |
buf[headerLen] = '.' | |
base64.RawURLEncoding.Encode(buf[headerLen+1:], claims) | |
var hash [sha256.Size]byte | |
alg := hmac.New(sha256.New, key) | |
alg.Write(buf[:headerLen+1+claimLen]) | |
alg.Sum(hash[:0]) | |
buf[headerLen+1+claimLen] = '.' | |
base64.RawURLEncoding.Encode(buf[headerLen+1+claimLen+1:], hash[:]) | |
return buf | |
} | |
// Token holds the parts of a JWT decoded by Parse. Parse does not verify the
// signature; call ValidateHS256 (or Validate) before trusting Claims.
//
// dataToSign and sig are sub-slices of the raw token given to Parse, so that
// input must stay unmodified while the Token is in use.
type Token struct { | |
// claims is the decoded payload; numbers are kept as json.Number.
claims Claims | |
// alg is the algorithm named by the token header.
alg Algorithm | |
// keyID is the header "kid"; empty when absent or when the canonical HS256 header was seen.
keyID string | |
// dataToSign is the encoded "header.claims" prefix covered by the signature.
dataToSign []byte | |
// sig is the still-encoded signature segment.
sig []byte | |
} | |
func (t *Token) Claims() Claims { | |
return t.claims | |
} | |
func (t *Token) Alg() Algorithm { | |
return t.alg | |
} | |
func (t *Token) KeyID() string { | |
if t.keyID != "" { | |
return t.keyID | |
} | |
return t.claims.String(KeyID) | |
} | |
// ParseString decodes JWT parts of a token. | |
func ParseString(rawToken string) (Token, error) { | |
var token Token | |
err := token.ParseString(rawToken) | |
return token, err | |
} | |
// Parse decodes JWT parts of a token. | |
func Parse(rawToken []byte) (Token, error) { | |
var token Token | |
err := token.Parse(rawToken) | |
return token, err | |
} | |
func DecodeHS256String(rawToken string, tolerance time.Duration, key []byte) (Claims, error) { | |
return DecodeHS256StringAt(rawToken, key, tolerance, time.Now()) | |
} | |
func DecodeHS256StringAt(rawToken string, key []byte, tolerance time.Duration, now time.Time) (Claims, error) { | |
var token Token | |
err := token.ParseString(rawToken) | |
if err != nil { | |
return nil, err | |
} | |
err = token.ValidateHS256(key) | |
if err != nil { | |
return nil, err | |
} | |
c := token.Claims() | |
err = c.ValidateTimeAt(tolerance, now) | |
return c, err | |
} | |
// ParseString decodes JWT parts of a token. | |
func (token *Token) ParseString(rawToken string) error { | |
return token.Parse(unsafe.Slice(unsafe.StringData(rawToken), len(rawToken))) | |
} | |
// Parse decodes JWT parts of a token. | |
func (token *Token) Parse(rawToken []byte) error { | |
if len(rawToken) > MaxTokenLen { | |
return ErrTooLong | |
} | |
i1 := bytes.IndexByte(rawToken, '.') | |
if i1 < 0 { | |
return ErrCorrupted | |
} | |
i2 := bytes.IndexByte(rawToken[i1+1:], '.') | |
if i2 < 0 { | |
return ErrCorrupted | |
} | |
i2 += i1 + 1 | |
h := rawToken[:i1] | |
if string(h) == hs256Header { | |
token.alg = HS256 | |
} else { | |
dbuf := make([]byte, base64.RawURLEncoding.DecodedLen(len(h))) | |
n, err := base64.RawURLEncoding.Decode(dbuf, h) | |
if err != nil { | |
return ErrCorrupted | |
} | |
var hdr header | |
err = json.Unmarshal(dbuf[:n], &hdr) | |
if err != nil { | |
return ErrCorrupted | |
} | |
if hdr.Typ != jwtTyp { | |
return ErrCorrupted | |
} | |
token.alg = Algorithm(hdr.Alg) | |
token.keyID = hdr.KeyID | |
} | |
token.sig = rawToken[i2+1:] | |
token.dataToSign = rawToken[:i2] | |
c := make(Claims, ExpectedClaimCount) | |
{ | |
raw := rawToken[i1+1 : i2] | |
n := base64.RawURLEncoding.DecodedLen(len(raw)) | |
// if claims data is small enough, decode into a stack buffer to avoid allocation | |
var stackBuf [stackClaimsSpace]byte | |
var buf []byte | |
if n < cap(stackBuf) { | |
buf = stackBuf[:] | |
} else { | |
buf = make([]byte, n) | |
} | |
// log.Printf("RawToken = %q", raw) | |
n, err := base64.RawURLEncoding.Decode(buf, raw) | |
if err != nil { | |
return ErrCorrupted | |
} | |
// log.Printf("JSONToken = %s", buf[:n]) | |
dec := json.NewDecoder(bytes.NewReader(buf[:n])) | |
dec.UseNumber() | |
err = dec.Decode(&c) | |
if err != nil { | |
// log.Printf("JSON err: %v", err) | |
return ErrCorrupted | |
} | |
} | |
token.claims = c | |
return nil | |
} | |
func (token *Token) Validate(alg Algorithm, key []byte) error { | |
switch alg { | |
case HS256: | |
return token.ValidateHS256(key) | |
default: | |
panic("unsupported algorithm") | |
} | |
} | |
func (token *Token) ValidateHS256(key []byte) error { | |
if len(key) == 0 { | |
panic("missing key") | |
} | |
if token.alg != HS256 { | |
return ErrAlg | |
} | |
var actualHash, expectedHash [sha256.Size]byte | |
if base64.RawURLEncoding.DecodedLen(len(token.sig)) != len(actualHash) { | |
// log.Printf("base64.RawURLEncoding.DecodedLen(len(token.sig)) %d != len(actualHash) %d", base64.RawURLEncoding.DecodedLen(len(token.sig)), len(actualHash)) | |
return ErrSignatureCorrupted | |
} | |
n, err := base64.RawURLEncoding.Decode(actualHash[:], token.sig) | |
if err != nil || n != len(actualHash) { | |
return ErrSignatureCorrupted | |
} | |
alg := hmac.New(sha256.New, key) | |
alg.Write(token.dataToSign) | |
alg.Sum(expectedHash[:0]) | |
// log.Printf("StringToSign = %q", token[:i2]) | |
// log.Printf("expectedHash = %q", base64.RawURLEncoding.EncodeToString(expectedHash[:])) | |
// log.Printf("actualHash = %q", base64.RawURLEncoding.EncodeToString(actualHash[:])) | |
if 1 != subtle.ConstantTimeCompare(actualHash[:], expectedHash[:]) { | |
return ErrSignature | |
} | |
return nil | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package jwt | |
import ( | |
"encoding/json" | |
"testing" | |
"time" | |
) | |
// Shared test fixtures: the HMAC key the test tokens are signed with, and a
// fixed "current time" so that time validation is deterministic.
var (
	testKey = []byte("helloworld")
	testNow = time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)
)
func TestRoundTrip(t *testing.T) { | |
tests := []struct { | |
name string | |
input string | |
payload string | |
newToken string | |
}{ | |
{ | |
"jwt.io", | |
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.dD_HjcF4ZoXwMj6Ov7q7uDqCZLeNMhOwC52WEGEG7P0", | |
`{"iat":1516239022,"name":"John Doe","sub":"1234567890"}`, | |
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQiOjE1MTYyMzkwMjIsIm5hbWUiOiJKb2huIERvZSIsInN1YiI6IjEyMzQ1Njc4OTAifQ.pBtuSBkUUz0-RMxWpH-uWr-4_C-AJiImWHiE7zxbcI4", | |
}, | |
{ | |
"jwt.io re-encoded via this library", | |
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQiOjE1MTYyMzkwMjIsIm5hbWUiOiJKb2huIERvZSIsInN1YiI6IjEyMzQ1Njc4OTAifQ.pBtuSBkUUz0-RMxWpH-uWr-4_C-AJiImWHiE7zxbcI4", | |
`{"iat":1516239022,"name":"John Doe","sub":"1234567890"}`, | |
"", | |
}, | |
} | |
for _, test := range tests { | |
t.Run(test.name, func(t *testing.T) { | |
var token Token | |
err := token.ParseString(test.input) | |
if err != nil { | |
t.Fatal(err) | |
} | |
err = token.ValidateHS256(testKey) | |
if err != nil { | |
t.Fatal(err) | |
} | |
c := token.Claims() | |
p := string(must(json.Marshal(c))) | |
if p != test.payload { | |
t.Errorf("DecodeHS256 = %s, wanted %s", p, test.payload) | |
} | |
err = c.ValidateTimeAt(0, testNow) | |
if err != nil { | |
t.Fatal(err) | |
} | |
output := SignHS256String(c, nil, testKey) | |
if test.newToken == "" { | |
test.newToken = test.input | |
} | |
if output != test.newToken { | |
t.Errorf("** SignHS256 = %q, wanted %q", output, test.newToken) | |
} | |
}) | |
} | |
} | |
// must returns v unchanged, panicking with err when err is non-nil. It keeps
// test code terse when a failure can only mean a broken fixture.
func must[T any](v T, err error) T {
	if err == nil {
		return v
	}
	panic(err)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment