Skip to content

Instantly share code, notes, and snippets.

@miguelmota
Created February 11, 2019 21:26
Show Gist options
  • Save miguelmota/06f563756448b0d4ce2ba508b3cbe6e2 to your computer and use it in GitHub Desktop.
Save miguelmota/06f563756448b0d4ce2ba508b3cbe6e2 to your computer and use it in GitHub Desktop.
Golang: validate AWS Cognito JWT tokens
package auth
import (
	"crypto/rsa"
	"encoding/base64"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"math/big"
	"net/http"
	"time"

	jwt "github.com/dgrijalva/jwt-go"
)
// Auth verifies tokens issued by one AWS Cognito user pool against the
// pool's JSON Web Key Set, which it keeps cached in memory.
type Auth struct {
	cognitoRegion     string // AWS region of the user pool
	cognitoUserPoolID string // ID of the user pool
	jwkURL            string // address the key set is downloaded from
	jwk               *JWK   // last successfully downloaded key set; nil before that
}
// Config names the Cognito user pool whose tokens an Auth should accept.
type Config struct {
	// CognitoRegion and CognitoUserPoolID are substituted into the pool's
	// JWKS URL by NewAuth.
	CognitoRegion, CognitoUserPoolID string
}
// JWK mirrors the JSON Web Key Set document a Cognito user pool serves at
// /.well-known/jwks.json. Only the members listed here are decoded; any other
// members in the document are ignored by encoding/json.
type JWK struct {
	// Keys holds one entry per published signing key.
	Keys []struct {
		Alg string `json:"alg"` // algorithm the key is intended for
		E   string `json:"e"`   // RSA public exponent, unpadded base64url
		Kid string `json:"kid"` // key identifier
		Kty string `json:"kty"` // key type
		N   string `json:"n"`   // RSA modulus, unpadded base64url
	} `json:"keys"`
}
// NewAuth creates an Auth for the user pool described by config and
// immediately downloads that pool's JWKS via CacheJWK. A download or decode
// failure is passed to log.Fatal, which terminates the process.
func NewAuth(config *Config) *Auth {
	region, poolID := config.CognitoRegion, config.CognitoUserPoolID
	a := &Auth{
		cognitoRegion:     region,
		cognitoUserPoolID: poolID,
		jwkURL:            fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json", region, poolID),
	}
	if err := a.CacheJWK(); err != nil {
		log.Fatal(err)
	}
	return a
}
// CacheJWK downloads the JSON Web Key Set from a.jwkURL and stores it on a,
// replacing any previously cached set. It returns an error if the request
// cannot be made, the server does not answer 200 OK, the body is not valid
// JSON, or the set contains no keys; on error the previously cached set is
// left untouched.
func (a *Auth) CacheJWK() error {
	req, err := http.NewRequest("GET", a.jwkURL, nil)
	if err != nil {
		return fmt.Errorf("building JWKS request: %w", err)
	}
	req.Header.Add("Accept", "application/json")

	// Bound the request so an unresponsive endpoint cannot hang the caller.
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("fetching JWKS from %s: %w", a.jwkURL, err)
	}
	defer resp.Body.Close()

	// Without this check an error response (e.g. a JSON error object) would
	// decode into a JWK with no keys and be cached as if it were valid.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("fetching JWKS from %s: unexpected status %s", a.jwkURL, resp.Status)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("reading JWKS response: %w", err)
	}

	jwk := new(JWK)
	if err := json.Unmarshal(body, jwk); err != nil {
		return fmt.Errorf("decoding JWKS: %w", err)
	}
	if len(jwk.Keys) == 0 {
		return fmt.Errorf("JWKS from %s contains no keys", a.jwkURL)
	}
	a.jwk = jwk
	return nil
}
// ParseJWT parses tokenString and verifies its signature against the cached
// JWKS. The verification key is the JWKS entry whose "kid" equals the token's
// "kid" header. Tokens are rejected if they are not RSA-signed, have no "kid",
// or name a key that is not in the cached set.
//
// The returned token may be non-nil even when err is non-nil. Trust it only
// when err is nil and token.Valid is true. Claims such as issuer, client_id
// and token_use are not checked here; callers must validate them as needed.
func (a *Auth) ParseJWT(tokenString string) (*jwt.Token, error) {
	return jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		// The cached keys are RSA public keys. Refuse any other algorithm so
		// the token cannot choose how the key is interpreted (e.g. HS256).
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method %v", token.Header["alg"])
		}
		kid, ok := token.Header["kid"].(string)
		if !ok || kid == "" {
			return nil, fmt.Errorf("token header has no kid")
		}
		if a.jwk == nil {
			return nil, fmt.Errorf("no JWKS cached; call CacheJWK first")
		}
		// Select the key by ID rather than by position: the set can hold
		// several keys and their order is not something to rely on.
		for _, key := range a.jwk.Keys {
			if key.Kid != kid {
				continue
			}
			if key.Kty != "RSA" {
				return nil, fmt.Errorf("JWK %q has unsupported key type %q", kid, key.Kty)
			}
			return convertKey(key.E, key.N), nil
		}
		return nil, fmt.Errorf("no JWK matches kid %q", kid)
	})
}
// JWK returns the key set stored by the most recent successful CacheJWK
// call, or nil if none has been stored.
func (a *Auth) JWK() (cached *JWK) {
	cached = a.jwk
	return
}
// JWKURL returns the address from which CacheJWK downloads the key set.
func (a *Auth) JWKURL() (endpoint string) {
	endpoint = a.jwkURL
	return
}
// https://gist.github.com/MathieuMailhos/361f24316d2de29e8d41e808e0071b13
func convertKey(rawE, rawN string) *rsa.PublicKey {
decodedE, err := base64.RawURLEncoding.DecodeString(rawE)
if err != nil {
panic(err)
}
if len(decodedE) < 4 {
ndata := make([]byte, 4)
copy(ndata[4-len(decodedE):], decodedE)
decodedE = ndata
}
pubKey := &rsa.PublicKey{
N: &big.Int{},
E: int(binary.BigEndian.Uint32(decodedE[:])),
}
decodedN, err := base64.RawURLEncoding.DecodeString(rawN)
if err != nil {
panic(err)
}
pubKey.N.SetBytes(decodedN)
return pubKey
}
package auth
import (
"os"
"testing"
)
// TestCacheJWT downloads the JWKS for the user pool named by
// AWS_COGNITO_REGION and AWS_COGNITO_USER_POOL_ID. When AWS_COGNITO_TEST_JWT
// holds a current token issued by that pool, it also checks that ParseJWT
// accepts the token. The test is skipped unless both Cognito variables are set,
// because it needs network access to a real pool.
func TestCacheJWT(t *testing.T) {
	region := os.Getenv("AWS_COGNITO_REGION")
	poolID := os.Getenv("AWS_COGNITO_USER_POOL_ID")
	if region == "" || poolID == "" {
		t.Skip("requires AWS Cognito environment variables")
	}
	auth := NewAuth(&Config{
		CognitoRegion:     region,
		CognitoUserPoolID: poolID,
	})
	if err := auth.CacheJWK(); err != nil {
		t.Fatal(err)
	}
	if jwk := auth.JWK(); jwk == nil || len(jwk.Keys) == 0 {
		t.Fatal("expected at least one cached JWK")
	}

	// Issued tokens expire (their exp claim is set shortly after iat), so a
	// token embedded in the source cannot stay valid. Take one from the
	// environment instead.
	tokenString := os.Getenv("AWS_COGNITO_TEST_JWT")
	if tokenString == "" {
		t.Log("AWS_COGNITO_TEST_JWT not set; skipping ParseJWT check")
		return
	}
	token, err := auth.ParseJWT(tokenString)
	if err != nil {
		// Fatal, not Error: token may be nil here and is dereferenced below.
		t.Fatal(err)
	}
	if !token.Valid {
		t.Fatal("expected token to be valid")
	}
}
@nabarunchatterjee
Copy link

@miguelmota Thanks. This has been very helpful.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment