Skip to content

Instantly share code, notes, and snippets.

@notdodo
Last active February 23, 2025 09:23
Show Gist options
  • Save notdodo/a1bfa0a7df21b26f2a3436731a28ed48 to your computer and use it in GitHub Desktop.
Save notdodo/a1bfa0a7df21b26f2a3436731a28ed48 to your computer and use it in GitHub Desktop.
Validation of "x-amzn-oidc-data" token from AWS ALB OIDC (JWT)
// Authors: [notdodo]
package main
import (
"crypto/ecdsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
"log"
"math/big"
"strings"
)
// main decodes an AWS ALB "x-amzn-oidc-data" token and verifies its ES256
// signature against a PEM-encoded public key.
//
// jwtStr and pubPemKey are placeholders: paste a real header value and the PEM
// key published for the token's kid before running.
func main() {
	// TODO: validate exp, signer and fetch the public key using kid in the header
	jwtStr := "x-amzn-oidc-data"
	jwtParts := strings.Split(jwtStr, ".")
	fmt.Println("Parts Count:", len(jwtParts))
	if len(jwtParts) != 3 {
		fmt.Println("Invalid JWT: expected 3 parts")
		return
	}

	header, err := base64.URLEncoding.DecodeString(jwtParts[0])
	if err != nil {
		fmt.Println("Error decoding header:", err)
		return
	}
	fmt.Println("Header:", string(header))

	payload, err := base64.URLEncoding.DecodeString(jwtParts[1])
	if err != nil {
		fmt.Println("Error decoding payload:", err)
		return
	}
	fmt.Println("Payload:", string(payload))

	signature, err := base64.URLEncoding.DecodeString(jwtParts[2])
	if err != nil {
		fmt.Println("Error decoding signature:", err)
		return
	}
	fmt.Println("Signature:", signature, len(signature))
	// ES256 signatures are raw r || s, 32 bytes each.
	if len(signature) != 64 {
		log.Panicf("Invalid signature length: got %d, expected 64", len(signature))
	}

	// The JWS signing input is the header and payload exactly as they appear in
	// the token (RFC 7515). Re-encoding the decoded bytes is not equivalent:
	// the non-strict Go decoder accepts non-zero trailing bits, which a
	// re-encode would silently normalise.
	hashInput := jwtParts[0] + "." + jwtParts[1]
	fmt.Println("input for hash", hashInput)
	digest := sha256.Sum256([]byte(hashInput))
	fmt.Println("digest: ", digest)

	pubPemKey := "-----BEGIN PUBLIC KEY-----\n<public-key>\n-----END PUBLIC KEY-----"
	block, _ := pem.Decode([]byte(pubPemKey))
	if block == nil || block.Type != "PUBLIC KEY" {
		panic("Failed to decode PEM public key")
	}
	pubEcdsaKey, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		// Keep the parser's reason; it is the only clue to what is wrong with the key.
		panic(fmt.Sprintf("Failed to parse ECDSA public key: %v", err))
	}
	ecdsaPubKey, ok := pubEcdsaKey.(*ecdsa.PublicKey)
	if !ok {
		panic("Invalid ECDSA public key")
	}

	// Split the raw signature into its big-endian scalars.
	r := new(big.Int).SetBytes(signature[:32])
	s := new(big.Int).SetBytes(signature[32:])
	if !ecdsa.Verify(ecdsaPubKey, digest[:], r, s) {
		panic("Signature verification failed")
	}
	fmt.Println("Signature verification successful")
}
// Authors: [notdodo, markusl]
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
use p256::{
ecdsa::{signature::hazmat::PrehashVerifier, Signature as P256Signature, VerifyingKey},
pkcs8::DecodePublicKey,
};
use serde::Deserialize;
use sha2::{Digest, Sha256};
use std::{
convert::TryFrom,
error::Error,
fmt::{self, Display, Formatter},
time::{SystemTime, UNIX_EPOCH},
};
/// Every way `verify_jwt` can reject a token (or fail while checking it).
#[derive(Debug)]
pub enum TokenError {
    /// A structural or sanity check failed; the static message says which one.
    InvalidFormat(&'static str),
    /// A token segment was not valid base64url.
    Base64DecodeError(base64::DecodeError),
    /// The decoded header or payload did not deserialize into the expected struct.
    JsonError(serde_json::Error),
    /// The header `exp` is not after the current time: `(exp, now)` in UNIX seconds.
    ExpiredToken(u64, u64),
    /// Fetching the public key failed (request, non-success status, or body read).
    ReqwestError(reqwest::Error),
    /// The fetched PEM could not be parsed into a P-256 verifying key.
    PublicKeyParseError,
    /// The decoded signature was not 64 bytes; holds the actual length.
    SignatureLengthError(usize),
    /// The raw `(r, s)` bytes did not form a valid signature.
    SignatureParseError,
    /// The signature did not verify against the fetched key.
    SignatureVerificationFailed,
}
impl Display for TokenError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
TokenError::InvalidFormat(msg) => write!(f, "Invalid token format: {}", msg),
TokenError::Base64DecodeError(err) => write!(f, "Base64 decode error: {}", err),
TokenError::JsonError(err) => write!(f, "JSON parse error: {}", err),
TokenError::ExpiredToken(exp, now) => {
write!(f, "Token is expired: exp={} <= now={}", exp, now)
}
TokenError::ReqwestError(err) => write!(f, "HTTP request error: {}", err),
TokenError::PublicKeyParseError => write!(f, "Failed to parse PEM public key"),
TokenError::SignatureLengthError(len) => {
write!(f, "Invalid signature length: expected 64, got {}", len)
}
TokenError::SignatureParseError => {
write!(f, "Failed to parse signature from raw (r,s)")
}
TokenError::SignatureVerificationFailed => write!(f, "Signature verification failed"),
}
}
}
impl Error for TokenError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
TokenError::Base64DecodeError(err) => Some(err),
TokenError::JsonError(err) => Some(err),
TokenError::ReqwestError(err) => Some(err),
_ => None,
}
}
}
/// Fields read from the token's header segment.
#[derive(Debug, Deserialize, PartialEq)]
pub struct Header {
    /// Key identifier; appended to the regional public-key URL to fetch the verifying key.
    pub kid: String,
    /// ARN of the signer; `verify_jwt` takes the region from its fourth `:`-separated field.
    pub signer: String,
    /// Expiry in seconds since the UNIX epoch, compared against the current system time.
    pub exp: u64,
}
/// Claims deserialized from the token's payload segment.
// NOTE(review): presumably allowed because some fields are never read inside the
// crate; confirm and narrow or remove the allow if possible.
#[allow(dead_code)]
#[derive(Debug, Deserialize, PartialEq)]
pub struct Claims {
    /// `sub` claim from the payload.
    pub sub: String,
    /// `name` claim from the payload.
    pub name: String,
    /// `preferred_username` claim from the payload.
    pub preferred_username: String,
}
/// Verifies an AWS ALB `x-amzn-oidc-data` token and returns its header and claims.
///
/// The token must be three `.`-separated, padded base64url segments. The ES256
/// signature is checked against the public key the ALB publishes for the
/// header's `kid`, fetched from the regional `public-keys.auth.elb` endpoint.
///
/// * `jwt_str` – the raw header value.
/// * `ignore_expired` – when `true`, skip the `exp` check (e.g. for recorded tokens).
///
/// # Errors
/// Returns a [`TokenError`] describing the first check that failed: token
/// structure, base64/JSON decoding, expiry, signer/kid validation, key fetch,
/// key parsing, or signature verification.
///
/// # Security
/// `signer` and `kid` come from the not-yet-verified header, so they are
/// validated before being placed into the key URL; otherwise a forged header
/// could redirect the key lookup to an attacker-controlled host. Callers should
/// still compare the returned `Header::signer` with the ARN of their own load
/// balancer, since this function does not know which ALB is expected.
pub(crate) fn verify_jwt(
    jwt_str: &str,
    ignore_expired: bool,
) -> Result<(Header, Claims), TokenError> {
    // 1) Split token into exactly 3 parts.
    let parts: Vec<&str> = jwt_str.split('.').collect();
    let (header_b64, payload_b64, signature_b64) = match parts.as_slice() {
        [h, p, s] => (*h, *p, *s),
        _ => return Err(TokenError::InvalidFormat("expected 3 parts")),
    };

    // 2) Decode and parse the header.
    let header_bytes = URL_SAFE
        .decode(header_b64)
        .map_err(TokenError::Base64DecodeError)?;
    let header: Header = serde_json::from_slice(&header_bytes).map_err(TokenError::JsonError)?;

    // 3) Check token expiration (`exp` lives in the header for ALB tokens).
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_err(|_| TokenError::InvalidFormat("system time before UNIX epoch"))?
        .as_secs();
    if header.exp <= now && !ignore_expired {
        return Err(TokenError::ExpiredToken(header.exp, now));
    }

    // 4) Decode and parse the payload.
    let payload_bytes = URL_SAFE
        .decode(payload_b64)
        .map_err(TokenError::Base64DecodeError)?;
    let claims: Claims = serde_json::from_slice(&payload_bytes).map_err(TokenError::JsonError)?;

    // 5) Decode the signature: raw r || s, which MUST be 64 bytes for P-256.
    let signature_bytes = URL_SAFE
        .decode(signature_b64)
        .map_err(TokenError::Base64DecodeError)?;
    if signature_bytes.len() != 64 {
        return Err(TokenError::SignatureLengthError(signature_bytes.len()));
    }

    // 6) Hash the signing input exactly as transmitted (RFC 7515). `URL_SAFE`
    //    only accepts canonical padded input, so this is byte-identical to
    //    decoding and re-encoding the segments, minus the extra allocations.
    let signing_input = format!("{}.{}", header_b64, payload_b64);
    let digest = Sha256::digest(signing_input.as_bytes());

    // 7) Build the key URL from the (still unverified) signer and kid.
    let url = alb_public_key_url(&header.signer, &header.kid)?;

    // 8) Fetch the PEM public key from the ALB/ELB public key endpoint.
    let pem_key = reqwest::blocking::get(&url)
        .map_err(TokenError::ReqwestError)?
        .error_for_status()
        .map_err(TokenError::ReqwestError)?
        .text()
        .map_err(TokenError::ReqwestError)?;

    // 9) Parse the ECDSA public key.
    let verifying_key =
        VerifyingKey::from_public_key_pem(&pem_key).map_err(|_| TokenError::PublicKeyParseError)?;

    // 10) Rebuild the signature from its raw (r, s) halves.
    let (r_raw, s_raw) = signature_bytes.split_at(32);
    let r_bytes = <[u8; 32]>::try_from(r_raw).map_err(|_| TokenError::SignatureParseError)?;
    let s_bytes = <[u8; 32]>::try_from(s_raw).map_err(|_| TokenError::SignatureParseError)?;
    let sig = P256Signature::from_scalars(r_bytes, s_bytes)
        .map_err(|_| TokenError::SignatureParseError)?;

    // 11) Verify the signature over the prehashed signing input.
    verifying_key
        .verify_prehash(&digest, &sig)
        .map_err(|_| TokenError::SignatureVerificationFailed)?;

    Ok((header, claims))
}

/// Builds the regional ALB public-key URL for `kid`, rejecting unsafe inputs.
///
/// `signer` is expected to look like
/// `arn:aws:elasticloadbalancing:<region>:<account>:loadbalancer/...`; its
/// fourth colon-delimited field is used as the region.
fn alb_public_key_url(signer: &str, kid: &str) -> Result<String, TokenError> {
    let arn_parts: Vec<&str> = signer.split(':').collect();
    if arn_parts.len() < 6 {
        return Err(TokenError::InvalidFormat(
            "signer is not a valid ARN-like string",
        ));
    }
    let region = arn_parts[3];

    // Region names use only lowercase letters, digits and '-' (e.g. `us-east-1`).
    // Rejecting anything else stops a forged signer from injecting '.', '/', '@',
    // '#', ... and steering the request to a host outside amazonaws.com.
    let region_ok = !region.is_empty()
        && region
            .bytes()
            .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit() || b == b'-');
    if !region_ok {
        return Err(TokenError::InvalidFormat("signer contains an invalid region"));
    }

    // `kid` becomes a URL path segment; keep it from adding path/query/fragment parts.
    let kid_ok = !kid.is_empty()
        && kid
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_');
    if !kid_ok {
        return Err(TokenError::InvalidFormat("kid contains invalid characters"));
    }

    Ok(format!(
        "https://public-keys.auth.elb.{}.amazonaws.com/{}",
        region, kid
    ))
}
#[cfg(test)]
mod tests {
    use super::{verify_jwt, Claims, Header};

    /// End-to-end check against a real ALB token.
    ///
    /// Ignored by default: it needs a genuine `x-amzn-oidc-data` value in
    /// `valid_token` (and matching expected fields) plus network access to the
    /// regional ELB public-key endpoint. Run with `cargo test -- --ignored`.
    #[test]
    #[ignore = "requires a real x-amzn-oidc-data token and network access"]
    fn test_verify_jwt1() {
        let valid_token = "add-token-here==";
        // `expect` prints the TokenError on failure instead of a bare assertion message.
        let (header, claims) =
            verify_jwt(valid_token, true).expect("verify_jwt should accept the token");
        assert_eq!(
            header,
            Header {
                kid: "<add-kid-here>".to_string(),
                signer: "arn:aws:elasticloadbalancing:<region>:<account>:loadbalancer/<app>"
                    .to_string(),
                exp: 1702884041,
            }
        );
        assert_eq!(
            claims,
            Claims {
                sub: "<sub>".to_string(),
                name: "<name>".to_string(),
                preferred_username: "<preferred_username>".to_string(),
            }
        );
    }
}
@markusl
Copy link

markusl commented Feb 22, 2025

Here is an example `jwt.rs` that also extracts the region from the `signer` ARN (to fetch the public key from the regional endpoint), plus a placeholder test:

use base64::{engine::general_purpose::URL_SAFE, Engine as _};
use p256::{
    ecdsa::{signature::hazmat::PrehashVerifier, Signature as P256Signature, VerifyingKey},
    pkcs8::DecodePublicKey,
};
use serde::Deserialize;
use sha2::{Digest, Sha256};
use std::{
    convert::TryFrom,
    error::Error,
    fmt::{self, Display, Formatter},
    time::{SystemTime, UNIX_EPOCH},
};

/// Every way `verify_jwt` can reject a token (or fail while checking it).
#[derive(Debug)]
pub enum TokenError {
    /// A structural or sanity check failed; the static message says which one.
    InvalidFormat(&'static str),
    /// A token segment was not valid base64url.
    Base64DecodeError(base64::DecodeError),
    /// The decoded header or payload did not deserialize into the expected struct.
    JsonError(serde_json::Error),
    /// The header `exp` is not after the current time: `(exp, now)` in UNIX seconds.
    ExpiredToken(u64, u64),
    /// Fetching the public key failed (request, non-success status, or body read).
    ReqwestError(reqwest::Error),
    /// The fetched PEM could not be parsed into a P-256 verifying key.
    PublicKeyParseError,
    /// The decoded signature was not 64 bytes; holds the actual length.
    SignatureLengthError(usize),
    /// The raw `(r, s)` bytes did not form a valid signature.
    SignatureParseError,
    /// The signature did not verify against the fetched key.
    SignatureVerificationFailed,
}

impl Display for TokenError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            TokenError::InvalidFormat(msg) => write!(f, "Invalid token format: {}", msg),
            TokenError::Base64DecodeError(err) => write!(f, "Base64 decode error: {}", err),
            TokenError::JsonError(err) => write!(f, "JSON parse error: {}", err),
            TokenError::ExpiredToken(exp, now) => {
                write!(f, "Token is expired: exp={} <= now={}", exp, now)
            }
            TokenError::ReqwestError(err) => write!(f, "HTTP request error: {}", err),
            TokenError::PublicKeyParseError => write!(f, "Failed to parse PEM public key"),
            TokenError::SignatureLengthError(len) => {
                write!(f, "Invalid signature length: expected 64, got {}", len)
            }
            TokenError::SignatureParseError => {
                write!(f, "Failed to parse signature from raw (r,s)")
            }
            TokenError::SignatureVerificationFailed => write!(f, "Signature verification failed"),
        }
    }
}

impl Error for TokenError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            TokenError::Base64DecodeError(err) => Some(err),
            TokenError::JsonError(err) => Some(err),
            TokenError::ReqwestError(err) => Some(err),
            _ => None,
        }
    }
}

/// Fields read from the token's header segment.
#[derive(Debug, Deserialize, PartialEq)]
pub struct Header {
    /// Key identifier; appended to the regional public-key URL to fetch the verifying key.
    pub kid: String,
    /// ARN of the signer; `verify_jwt` takes the region from its fourth `:`-separated field.
    pub signer: String,
    /// Expiry in seconds since the UNIX epoch, compared against the current system time.
    pub exp: u64,
}

/// Claims deserialized from the token's payload segment.
// NOTE(review): presumably allowed because some fields are never read inside the
// crate; confirm and narrow or remove the allow if possible.
#[allow(dead_code)]
#[derive(Debug, Deserialize, PartialEq)]
pub struct Claims {
    /// `sub` claim from the payload.
    pub sub: String,
    /// `name` claim from the payload.
    pub name: String,
    /// `preferred_username` claim from the payload.
    pub preferred_username: String,
}

pub(crate) fn verify_jwt(jwt_str: &str, ignore_expired: bool) -> Result<(Header, Claims), TokenError> {
    // 1) Split token into 3 parts
    let parts: Vec<&str> = jwt_str.split('.').collect();
    if parts.len() != 3 {
        return Err(TokenError::InvalidFormat("expected 3 parts"));
    }

    // 2) Decode and parse the header
    let header_bytes = URL_SAFE
        .decode(parts[0])
        .map_err(TokenError::Base64DecodeError)?;
    let header: Header = serde_json::from_slice(&header_bytes).map_err(TokenError::JsonError)?;

    // Check token expiration.
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_err(|_| TokenError::InvalidFormat("system time before UNIX epoch"))?
        .as_secs();
    if header.exp <= now && !ignore_expired {
        return Err(TokenError::ExpiredToken(header.exp, now));
    }

    // 3) Decode and parse the payload
    let payload_bytes = URL_SAFE
        .decode(parts[1])
        .map_err(TokenError::Base64DecodeError)?;
    let claims: Claims = serde_json::from_slice(&payload_bytes).map_err(TokenError::JsonError)?;

    // 4) Decode the signature
    let signature_bytes = URL_SAFE
        .decode(parts[2])
        .map_err(TokenError::Base64DecodeError)?;

    // Signature MUST be 64 bytes
    if signature_bytes.len() != 64 {
        return Err(TokenError::SignatureLengthError(signature_bytes.len()));
    }

    // 5) Re-encode header and payload to base64-url.
    let header_b64url = URL_SAFE.encode(&header_bytes);
    let payload_b64url = URL_SAFE.encode(&payload_bytes);
    let hash_input = format!("{}.{}", header_b64url, payload_b64url);

    // 6) Compute SHA-256 of that exact input.
    let digest = Sha256::digest(hash_input.as_bytes());

    // 7) Extract region from the `signer` field (ARN format)
    //    For example: arn:aws:elasticloadbalancing:us-east-1:...:loadbalancer/app/...
    let arn_parts: Vec<&str> = header.signer.split(':').collect();
    if arn_parts.len() < 6 {
        return Err(TokenError::InvalidFormat("signer is not a valid ARN-like string"));
    }
    let region = arn_parts[3]; // The 4th colon-delimited part is the region

    // 8) Fetch public key from the ALB/ELB public key endpoint
    let url = format!(
      "https://public-keys.auth.elb.{}.amazonaws.com/{}",
      region,
      header.kid
  );
    let pem_key = reqwest::blocking::get(&url)
        .map_err(TokenError::ReqwestError)?
        .error_for_status()
        .map_err(TokenError::ReqwestError)?
        .text()
        .map_err(TokenError::ReqwestError)?;

    // 8) Parse the ECDSA public key
    let verifying_key =
        VerifyingKey::from_public_key_pem(&pem_key).map_err(|_| TokenError::PublicKeyParseError)?;

    // 9) Generate the raw signature as (r, s).
    let r_bytes = <[u8; 32]>::try_from(&signature_bytes[..32])
        .map_err(|_| TokenError::SignatureParseError)?;
    let s_bytes = <[u8; 32]>::try_from(&signature_bytes[32..])
        .map_err(|_| TokenError::SignatureParseError)?;

    let sig = P256Signature::from_scalars(r_bytes, s_bytes)
        .map_err(|_| TokenError::SignatureParseError)?;

    // 10) Verify the signature
    verifying_key
        .verify_prehash(&digest, &sig)
        .map_err(|_| TokenError::SignatureVerificationFailed)?;

    Ok((header, claims))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check against a real ALB token.
    ///
    /// Ignored by default: it needs a genuine `x-amzn-oidc-data` value in
    /// `valid_token` (and matching expected fields) plus network access to the
    /// regional ELB public-key endpoint. Run with `cargo test -- --ignored`.
    #[test]
    #[ignore = "requires a real x-amzn-oidc-data token and network access"]
    fn test_verify_jwt1() {
        let valid_token = "add-token-here==";
        // `expect` prints the TokenError on failure instead of `left: false, right: true`.
        let (header, claims) =
            verify_jwt(valid_token, true).expect("verify_jwt should accept the token");
        assert_eq!(
            header,
            Header {
                kid: "<add-kid-here>".to_string(),
                signer: "arn:aws:elasticloadbalancing:<region>:<account>:loadbalancer/<app>"
                    .to_string(),
                exp: 1702884041,
            }
        );
        // Every `Claims` field must be listed, or the struct literal does not compile.
        assert_eq!(
            claims,
            Claims {
                sub: "<sub>".to_string(),
                name: "<name>".to_string(),
                preferred_username: "<preferred_username>".to_string(),
            }
        );
    }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment