-
-
Save miguelmota/06f563756448b0d4ce2ba508b3cbe6e2 to your computer and use it in GitHub Desktop.
package auth | |
import ( | |
"crypto/rsa" | |
"encoding/base64" | |
"encoding/binary" | |
"encoding/json" | |
"fmt" | |
"io/ioutil" | |
"log" | |
"math/big" | |
"net/http" | |
jwt "github.com/dgrijalva/jwt-go" | |
) | |
// Auth ... | |
type Auth struct { | |
jwk *JWK | |
jwkURL string | |
cognitoRegion string | |
cognitoUserPoolID string | |
} | |
// Config identifies the AWS Cognito user pool whose tokens an Auth verifies.
// NewAuth interpolates both values into the pool's JWKS URL.
type Config struct {
	// CognitoRegion is the pool's AWS region; CognitoUserPoolID is its ID.
	CognitoRegion, CognitoUserPoolID string
}
// JWK mirrors the JSON Web Key Set document served at a Cognito user pool's
// .well-known/jwks.json endpoint. Only the members listed below are decoded
// from each key. E and N are the unpadded base64url-encoded RSA exponent and
// modulus that convertKey turns into an *rsa.PublicKey.
type JWK struct {
	Keys []struct {
		Alg string `json:"alg"` // algorithm the key is intended for
		E   string `json:"e"`   // RSA public exponent (base64url)
		Kid string `json:"kid"` // key ID
		Kty string `json:"kty"` // key type
		N   string `json:"n"`   // RSA modulus (base64url)
	} `json:"keys"`
}
// NewAuth ... | |
func NewAuth(config *Config) *Auth { | |
a := &Auth{ | |
cognitoRegion: config.CognitoRegion, | |
cognitoUserPoolID: config.CognitoUserPoolID, | |
} | |
a.jwkURL = fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json", a.cognitoRegion, a.cognitoUserPoolID) | |
err := a.CacheJWK() | |
if err != nil { | |
log.Fatal(err) | |
} | |
return a | |
} | |
// CacheJWK ... | |
func (a *Auth) CacheJWK() error { | |
req, err := http.NewRequest("GET", a.jwkURL, nil) | |
if err != nil { | |
return err | |
} | |
req.Header.Add("Accept", "application/json") | |
client := &http.Client{} | |
resp, err := client.Do(req) | |
if err != nil { | |
return err | |
} | |
defer resp.Body.Close() | |
body, err := ioutil.ReadAll(resp.Body) | |
if err != nil { | |
return err | |
} | |
jwk := new(JWK) | |
err = json.Unmarshal(body, jwk) | |
if err != nil { | |
return err | |
} | |
a.jwk = jwk | |
return nil | |
} | |
// ParseJWT ... | |
func (a *Auth) ParseJWT(tokenString string) (*jwt.Token, error) { | |
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { | |
key := convertKey(a.jwk.Keys[1].E, a.jwk.Keys[1].N) | |
return key, nil | |
}) | |
if err != nil { | |
return token, err | |
} | |
return token, nil | |
} | |
// JWK ... | |
func (a *Auth) JWK() *JWK { | |
return a.jwk | |
} | |
// JWKURL ... | |
func (a *Auth) JWKURL() string { | |
return a.jwkURL | |
} | |
// convertKey builds an RSA public key from the "e" (exponent) and "n"
// (modulus) members of a JWK. Both members are unpadded base64url encodings
// of big-endian unsigned integers.
//
// It panics if either value is not valid unpadded base64url, or if the
// exponent does not fit in an int.
//
// Adapted from https://gist.github.com/MathieuMailhos/361f24316d2de29e8d41e808e0071b13
func convertKey(rawE, rawN string) *rsa.PublicKey {
	decodedE, err := base64.RawURLEncoding.DecodeString(rawE)
	if err != nil {
		panic(fmt.Errorf("decoding JWK exponent: %w", err))
	}

	// Strip redundant leading zero bytes, then left-pad to 8 bytes. Reading
	// a fixed 4 bytes would keep only the leading 4 bytes of a longer
	// exponent and silently yield the wrong key.
	for len(decodedE) > 0 && decodedE[0] == 0 {
		decodedE = decodedE[1:]
	}
	if len(decodedE) > 8 {
		panic(fmt.Errorf("JWK exponent is too large (%d bytes)", len(decodedE)))
	}
	var padded [8]byte
	copy(padded[8-len(decodedE):], decodedE)
	e := binary.BigEndian.Uint64(padded[:])

	const maxInt = int(^uint(0) >> 1)
	if e > uint64(maxInt) {
		panic(fmt.Errorf("JWK exponent %d overflows int", e))
	}

	decodedN, err := base64.RawURLEncoding.DecodeString(rawN)
	if err != nil {
		panic(fmt.Errorf("decoding JWK modulus: %w", err))
	}

	return &rsa.PublicKey{
		N: new(big.Int).SetBytes(decodedN),
		E: int(e),
	}
}
package auth | |
import ( | |
"os" | |
"testing" | |
) | |
func TestCacheJWT(t *testing.T) { | |
if !(os.Getenv("AWS_COGNITO_USER_POOL_ID") != "" && os.Getenv("AWS_COGNITO_REGION") != "") { | |
t.Skip("requires AWS Cognito environment variables") | |
} | |
auth := NewAuth(&Config{ | |
CognitoRegion: os.Getenv("AWS_COGNITO_REGION"), | |
CognitoUserPoolID: os.Getenv("AWS_COGNITO_USER_POOL_ID"), | |
}) | |
err := auth.CacheJWK() | |
if err != nil { | |
t.Error(err) | |
} | |
jwt := "eyJraWQiOiJlS3lvdytnb1wvXC9yWmtkbGFhRFNOM25jTTREd0xTdFhibks4TTB5b211aE09IiwiYWxnIjoiUlMyNTYifQ.eyJzdWIiOiJjMTcxOGY3OC00ODY5LTRmMmEtYTk2ZS1lYmEwYmJkY2RkMjEiLCJldmVudF9pZCI6IjZmYWMyZGNjLTJlMzUtMTFlOS05NDZjLTZiZDI0YmRlZjFiNiIsInRva2VuX3VzZSI6ImFjY2VzcyIsInNjb3BlIjoiYXdzLmNvZ25pdG8uc2lnbmluLnVzZXIuYWRtaW4iLCJhdXRoX3RpbWUiOjE1NDk5MTQyNjUsImlzcyI6Imh0dHBzOlwvXC9jb2duaXRvLWlkcC51cy13ZXN0LTIuYW1hem9uYXdzLmNvbVwvdXMtd2VzdC0yX0wwVldGSEVueSIsImV4cCI6MTU0OTkxNzg2NSwiaWF0IjoxNTQ5OTE0MjY1LCJqdGkiOiIzMTg0MDdkMC0zZDNhLTQ0NDItOTMyYy1lY2I0MjQ2MzRiYjIiLCJjbGllbnRfaWQiOiI2ZjFzcGI2MzZwdG4wNzRvbjBwZGpnbms4bCIsInVzZXJuYW1lIjoiYzE3MThmNzgtNDg2OS00ZjJhLWE5NmUtZWJhMGJiZGNkZDIxIn0.rJl9mdCrw_lertWhC5RiJcfhRP-xwTYkPLPXmi_NQEO-LtIJ-kwVEvUaZsPnBXku3bWBM3V35jdJloiXclbffl4SDLVkkvU9vzXDETAMaZEzOY1gDVcg4YzNNR4H5kHnl-G-XiN5MajgaWbjohDHTvbPnqgW7e_4qNVXueZv2qfQ8hZ_VcyniNxMGaui-C0_YuR6jdH-T14Wl59Cyf-UFEyli1NZFlmpUQ8QODGMUI12PVFOZiHJIOZ3CQM_Xs-TlRy53RlKGFzf6RQfRm57rJw_zLyJHHnB8DZgbdCRfhNsqZka7ZZUUAlS9aMzdmSc3pPFSJ-hH3p8eFAgB4E71g" | |
token, err := auth.ParseJWT(jwt) | |
if err != nil { | |
t.Error(err) | |
} | |
if !token.Valid { | |
t.Fail() | |
} | |
} |
@miguelmota thanks for posting this, it was super helpful!
I did run into one problem with the code: the KeyFunc hard-codes the second JWK instead of looking up the correct JWK by KID. On line 91:
key := convertKey(a.jwk.Keys[1].E, a.jwk.Keys[1].N)
This was failing for me because my JWT happened to use the first KID instead of the second. So I created a new gist that fixes this by looking up the JWK by KID; it also uses https://github.com/golang-jwt/jwt, as mentioned by @Quixotical. Feel free to copy it if you like: https://gist.github.com/stream-ai-llc/2c27416d01c99cbeea9fdd07d74e8b0f.
But thanks again, this gist really helped!
There is a problem with this line:
key := convertKey(a.jwk.Keys[1].E, a.jwk.Keys[1].N)
It assumes the signing key is always the second key in the set (index 1). You need to match the token's kid instead. Replace it with:
index := 0
found := false
for i, v := range a.jwk.Keys {
if v.Kid == token.Header["kid"] {
index = i
found = true
}
}
if !found {
return nil, errors.New("Key Not found")
}
if token.Method.Alg() != "RS256" {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}
key := convertKey(a.jwk.Keys[index].E, a.jwk.Keys[index].N)
@miguelmota Thanks. This has been very helpful.
One thing to note for anyone who comes here: the JWT package used here (github.com/dgrijalva/jwt-go) is now deprecated, so you should use https://github.com/golang-jwt/jwt instead. This code works unchanged with the new module.