Last active
February 23, 2025 09:23
-
-
Save notdodo/a1bfa0a7df21b26f2a3436731a28ed48 to your computer and use it in GitHub Desktop.
Validation of "x-amzn-oidc-data" token from AWS ALB OIDC (JWT)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Authors: [notdodo] | |
package main | |
import ( | |
"crypto/ecdsa" | |
"crypto/sha256" | |
"crypto/x509" | |
"encoding/base64" | |
"encoding/pem" | |
"fmt" | |
"log" | |
"math/big" | |
"strings" | |
) | |
// main verifies a hard-coded ALB "x-amzn-oidc-data" token against a hard-coded
// PEM-encoded ECDSA public key. On success it prints the decoded header and
// payload; on any failure it logs the error and exits with a non-zero status.
//
// TODO: validate exp, signer and fetch the public key using kid in the header
func main() {
	jwtStr := "x-amzn-oidc-data"
	pubPemKey := "-----BEGIN PUBLIC KEY-----\n<public-key>\n-----END PUBLIC KEY-----"

	header, payload, err := verifyALBToken(jwtStr, pubPemKey)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("Header:", string(header))
	fmt.Println("Payload:", string(payload))
	fmt.Println("Signature verification successful")
}

// decodeSegment decodes one base64url JWT segment. Segments are accepted both
// with and without trailing '=' padding (ALB tokens are documented by AWS to
// include padding, while RFC 7515 segments omit it).
func decodeSegment(segment string) ([]byte, error) {
	return base64.RawURLEncoding.DecodeString(strings.TrimRight(segment, "="))
}

// parseECDSAPublicKey parses a PEM "PUBLIC KEY" (PKIX) block and requires the
// contained key to be an ECDSA public key.
func parseECDSAPublicKey(pemKey string) (*ecdsa.PublicKey, error) {
	block, _ := pem.Decode([]byte(pemKey))
	if block == nil || block.Type != "PUBLIC KEY" {
		return nil, fmt.Errorf("failed to decode PEM public key")
	}
	parsed, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("failed to parse public key: %w", err)
	}
	key, ok := parsed.(*ecdsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("public key is %T, not ECDSA", parsed)
	}
	return key, nil
}

// verifyALBToken checks the ECDSA/SHA-256 signature of jwtStr against pemKey
// and returns the decoded header and payload JSON. It does NOT validate exp
// or signer; see the TODO on main.
func verifyALBToken(jwtStr, pemKey string) ([]byte, []byte, error) {
	parts := strings.Split(jwtStr, ".")
	if len(parts) != 3 {
		return nil, nil, fmt.Errorf("invalid JWT: expected 3 parts, got %d", len(parts))
	}

	header, err := decodeSegment(parts[0])
	if err != nil {
		return nil, nil, fmt.Errorf("decoding header: %w", err)
	}
	payload, err := decodeSegment(parts[1])
	if err != nil {
		return nil, nil, fmt.Errorf("decoding payload: %w", err)
	}
	signature, err := decodeSegment(parts[2])
	if err != nil {
		return nil, nil, fmt.Errorf("decoding signature: %w", err)
	}
	// The signature is the raw 32-byte r and s values concatenated.
	if len(signature) != 64 {
		return nil, nil, fmt.Errorf("invalid signature length: got %d, expected 64", len(signature))
	}

	// The signing input is the header and payload exactly as they appear in
	// the token; re-encoding the decoded bytes may not reproduce them.
	digest := sha256.Sum256([]byte(parts[0] + "." + parts[1]))

	key, err := parseECDSAPublicKey(pemKey)
	if err != nil {
		return nil, nil, err
	}
	r := new(big.Int).SetBytes(signature[:32])
	s := new(big.Int).SetBytes(signature[32:])
	if !ecdsa.Verify(key, digest[:], r, s) {
		return nil, nil, fmt.Errorf("signature verification failed")
	}
	return header, payload, nil
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Authors: [notdodo, markusl] | |
use base64::{engine::general_purpose::URL_SAFE, Engine as _}; | |
use p256::{ | |
ecdsa::{signature::hazmat::PrehashVerifier, Signature as P256Signature, VerifyingKey}, | |
pkcs8::DecodePublicKey, | |
}; | |
use serde::Deserialize; | |
use sha2::{Digest, Sha256}; | |
use std::{ | |
convert::TryFrom, | |
error::Error, | |
fmt::{self, Display, Formatter}, | |
time::{SystemTime, UNIX_EPOCH}, | |
}; | |
/// Represents an error that can occur during token verification. | |
#[derive(Debug)] | |
pub enum TokenError { | |
InvalidFormat(&'static str), | |
Base64DecodeError(base64::DecodeError), | |
JsonError(serde_json::Error), | |
ExpiredToken(u64, u64), | |
ReqwestError(reqwest::Error), | |
PublicKeyParseError, | |
SignatureLengthError(usize), | |
SignatureParseError, | |
SignatureVerificationFailed, | |
} | |
impl Display for TokenError { | |
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { | |
match self { | |
TokenError::InvalidFormat(msg) => write!(f, "Invalid token format: {}", msg), | |
TokenError::Base64DecodeError(err) => write!(f, "Base64 decode error: {}", err), | |
TokenError::JsonError(err) => write!(f, "JSON parse error: {}", err), | |
TokenError::ExpiredToken(exp, now) => { | |
write!(f, "Token is expired: exp={} <= now={}", exp, now) | |
} | |
TokenError::ReqwestError(err) => write!(f, "HTTP request error: {}", err), | |
TokenError::PublicKeyParseError => write!(f, "Failed to parse PEM public key"), | |
TokenError::SignatureLengthError(len) => { | |
write!(f, "Invalid signature length: expected 64, got {}", len) | |
} | |
TokenError::SignatureParseError => { | |
write!(f, "Failed to parse signature from raw (r,s)") | |
} | |
TokenError::SignatureVerificationFailed => write!(f, "Signature verification failed"), | |
} | |
} | |
} | |
impl Error for TokenError { | |
fn source(&self) -> Option<&(dyn Error + 'static)> { | |
match self { | |
TokenError::Base64DecodeError(err) => Some(err), | |
TokenError::JsonError(err) => Some(err), | |
TokenError::ReqwestError(err) => Some(err), | |
_ => None, | |
} | |
} | |
} | |
#[derive(Debug, Deserialize, PartialEq)] | |
pub struct Header { | |
pub kid: String, | |
pub signer: String, | |
pub exp: u64, | |
} | |
#[allow(dead_code)] | |
#[derive(Debug, Deserialize, PartialEq)] | |
pub struct Claims { | |
pub sub: String, | |
pub name: String, | |
pub preferred_username: String, | |
} | |
pub(crate) fn verify_jwt( | |
jwt_str: &str, | |
ignore_expired: bool, | |
) -> Result<(Header, Claims), TokenError> { | |
// 1) Split token into 3 parts | |
let parts: Vec<&str> = jwt_str.split('.').collect(); | |
if parts.len() != 3 { | |
return Err(TokenError::InvalidFormat("expected 3 parts")); | |
} | |
// 2) Decode and parse the header | |
let header_bytes = URL_SAFE | |
.decode(parts[0]) | |
.map_err(TokenError::Base64DecodeError)?; | |
let header: Header = serde_json::from_slice(&header_bytes).map_err(TokenError::JsonError)?; | |
// Check token expiration. | |
let now = SystemTime::now() | |
.duration_since(UNIX_EPOCH) | |
.map_err(|_| TokenError::InvalidFormat("system time before UNIX epoch"))? | |
.as_secs(); | |
if header.exp <= now && !ignore_expired { | |
return Err(TokenError::ExpiredToken(header.exp, now)); | |
} | |
// 3) Decode and parse the payload | |
let payload_bytes = URL_SAFE | |
.decode(parts[1]) | |
.map_err(TokenError::Base64DecodeError)?; | |
let claims: Claims = serde_json::from_slice(&payload_bytes).map_err(TokenError::JsonError)?; | |
// 4) Decode the signature | |
let signature_bytes = URL_SAFE | |
.decode(parts[2]) | |
.map_err(TokenError::Base64DecodeError)?; | |
// Signature MUST be 64 bytes | |
if signature_bytes.len() != 64 { | |
return Err(TokenError::SignatureLengthError(signature_bytes.len())); | |
} | |
// 5) Re-encode header and payload to base64-url. | |
let header_b64url = URL_SAFE.encode(&header_bytes); | |
let payload_b64url = URL_SAFE.encode(&payload_bytes); | |
let hash_input = format!("{}.{}", header_b64url, payload_b64url); | |
// 6) Compute SHA-256 of that exact input. | |
let digest = Sha256::digest(hash_input.as_bytes()); | |
// 7) Extract region from the `signer` field (ARN format) | |
// For example: arn:aws:elasticloadbalancing:us-east-1:...:loadbalancer/app/... | |
let arn_parts: Vec<&str> = header.signer.split(':').collect(); | |
if arn_parts.len() < 6 { | |
return Err(TokenError::InvalidFormat( | |
"signer is not a valid ARN-like string", | |
)); | |
} | |
let region = arn_parts[3]; // The 4th colon-delimited part is the region | |
// 8) Fetch public key from the ALB/ELB public key endpoint | |
let url = format!( | |
"https://public-keys.auth.elb.{}.amazonaws.com/{}", | |
region, header.kid | |
); | |
let pem_key = reqwest::blocking::get(&url) | |
.map_err(TokenError::ReqwestError)? | |
.error_for_status() | |
.map_err(TokenError::ReqwestError)? | |
.text() | |
.map_err(TokenError::ReqwestError)?; | |
// 8) Parse the ECDSA public key | |
let verifying_key = | |
VerifyingKey::from_public_key_pem(&pem_key).map_err(|_| TokenError::PublicKeyParseError)?; | |
// 9) Generate the raw signature as (r, s). | |
let r_bytes = <[u8; 32]>::try_from(&signature_bytes[..32]) | |
.map_err(|_| TokenError::SignatureParseError)?; | |
let s_bytes = <[u8; 32]>::try_from(&signature_bytes[32..]) | |
.map_err(|_| TokenError::SignatureParseError)?; | |
let sig = P256Signature::from_scalars(r_bytes, s_bytes) | |
.map_err(|_| TokenError::SignatureParseError)?; | |
// 10) Verify the signature | |
verifying_key | |
.verify_prehash(&digest, &sig) | |
.map_err(|_| TokenError::SignatureVerificationFailed)?; | |
Ok((header, claims)) | |
} | |
#[cfg(test)]
mod tests {
    use super::{verify_jwt, Claims, Header, TokenError};

    /// Tokens without exactly three dot-separated segments are rejected up
    /// front, before any decoding or network access.
    #[test]
    fn rejects_token_without_three_parts() {
        for token in ["", "a", "a.b", "a.b.c.d"] {
            assert!(matches!(
                verify_jwt(token, true),
                Err(TokenError::InvalidFormat(_))
            ));
        }
    }

    /// End-to-end check against a real ALB token.
    ///
    /// Ignored by default because it needs a real `x-amzn-oidc-data` token
    /// pasted in below and network access to the regional ALB key endpoint.
    /// Fill in the placeholders and run with `cargo test -- --ignored`.
    #[test]
    #[ignore = "requires a real ALB token and network access"]
    fn test_verify_jwt1() {
        let valid_token = "add-token-here==";
        // `ignore_expired = true` so a captured token keeps working after `exp`.
        let (header, claims) = verify_jwt(valid_token, true)
            .unwrap_or_else(|err| panic!("token verification failed: {}", err));
        assert_eq!(
            header,
            Header {
                kid: "<add-kid-here>".to_string(),
                signer: "arn:aws:elasticloadbalancing:<region>:<account>:loadbalancer/<app>"
                    .to_string(),
                exp: 1702884041,
            }
        );
        assert_eq!(
            claims,
            Claims {
                sub: "<sub>".to_string(),
                name: "<name>".to_string(),
                preferred_username: "<preferred_username>".to_string(),
            }
        );
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Here is an example `jwt.rs` that also extracts the AWS region from the `signer` ARN (to fetch the matching public key) and includes a placeholder test: