Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save remelpugh/73adaa98928c7b08f24b0c0f9cea0ba9 to your computer and use it in GitHub Desktop.
Save remelpugh/73adaa98928c7b08f24b0c0f9cea0ba9 to your computer and use it in GitHub Desktop.
Using Azure Key Vault keys for signing and encrypting JSON Web Tokens
// Key Vault that holds the signing and encryption keys (placeholder URI).
var vaultUri = new Uri("https://your-key-vault.vault.azure.net/");

// Authenticate with the Azure CLI login, pinned to a specific tenant (placeholder ID).
var credential = new AzureCliCredential(new AzureCliCredentialOptions
{
    TenantId = "your-aad-tenant-id"
});
var keyClient = new KeyClient(vaultUri, credential);

// Route crypto operations for Key Vault-backed keys through the custom provider.
var cryptoProviderFactory = new CryptoProviderFactory();
cryptoProviderFactory.CustomCryptoProvider = new KeyVaultCryptoProvider(keyClient);

// Signing key: fetched by name and a pinned key version.
var signingKey = await keyClient.GetKeyAsync("TestSigningKey", "c1d4752f020b4a77a9c899901db7c7cd");
var signingRsaKey = new KeyVaultRsaSecurityKey(signingKey)
{
    CryptoProviderFactory = cryptoProviderFactory
};
var signingCredentials = new SigningCredentials(signingRsaKey, SecurityAlgorithms.RsaSha256);

// Encryption key: wraps the per-token content-encryption key with RSA-OAEP.
var encryptionKey = await keyClient.GetKeyAsync("TestEncryptionKey", "b073d79dcaa74d7d9c7588a475b4fd91");
var encryptionRsaKey = new KeyVaultRsaSecurityKey(encryptionKey)
{
    CryptoProviderFactory = cryptoProviderFactory
};
var encryptingCredentials = new EncryptingCredentials(encryptionRsaKey, SecurityAlgorithms.RsaOAEP, SecurityAlgorithms.Aes128CbcHmacSha256);

var handler = new JsonWebTokenHandler();

// Take a single timestamp so iat, nbf and exp are mutually consistent
// (separate UtcNow reads could straddle a second boundary).
var now = DateTimeOffset.UtcNow;
var encryptedToken = handler.CreateToken(
    JsonConvert.SerializeObject(new
    {
        sub = "test-user-id",
        aud = "TestApp",
        iss = "https://zure.com",
        iat = now.ToUnixTimeSeconds(),
        nbf = now.ToUnixTimeSeconds(),
        exp = now.AddDays(1).ToUnixTimeSeconds(),
    }),
    signingCredentials,
    encryptingCredentials);

var validationResult = await handler.ValidateTokenAsync(encryptedToken, new TokenValidationParameters
{
    IssuerSigningKeys = new List<SecurityKey>
    {
        signingRsaKey
    },
    TokenDecryptionKeys = new List<SecurityKey>
    {
        encryptionRsaKey
    },
    // Only the key whose KeyId matches the token's "kid" header is tried.
    TryAllIssuerSigningKeys = false,
    ValidAudience = "TestApp",
    ValidIssuer = "https://zure.com",
    // No tolerance: lifetime is validated against exact nbf/exp values.
    ClockSkew = TimeSpan.Zero,
    ValidAlgorithms = new List<string>
    {
        SecurityAlgorithms.RsaSha256,
        SecurityAlgorithms.Aes128CbcHmacSha256,
    },
    ValidateAudience = true,
    ValidateIssuer = true,
    ValidateIssuerSigningKey = true,
    ValidateLifetime = true,
});

bool isValid = validationResult.IsValid;
if (!isValid)
{
    // Check validationResult.Exception for the reason validation failed.
}

IDictionary<string, object> claims = validationResult.Claims;
/// <summary>
/// <see cref="ICryptoProvider"/> plugged into <see cref="CryptoProviderFactory.CustomCryptoProvider"/>.
/// Routes RSA signing/verification (RS256/RS384/RS512) and RSA-OAEP key wrapping for
/// <see cref="KeyVaultRsaSecurityKey"/> instances to Key Vault-backed providers, and supplies an
/// <see cref="AuthenticatedEncryptionProvider"/> for A128CBC-HS256 content encryption with a symmetric key.
/// </summary>
public class KeyVaultCryptoProvider : ICryptoProvider
{
    private readonly KeyClient _keyClient;

    /// <summary>Creates the provider.</summary>
    /// <param name="keyClient">Client used to obtain a <see cref="CryptographyClient"/> per key name/version.</param>
    /// <exception cref="ArgumentNullException"><paramref name="keyClient"/> is null.</exception>
    public KeyVaultCryptoProvider(KeyClient keyClient)
    {
        _keyClient = keyClient ?? throw new ArgumentNullException(nameof(keyClient));
    }

    /// <summary>
    /// Reports whether <see cref="Create"/> can build a provider for <paramref name="algorithm"/>.
    /// </summary>
    /// <param name="algorithm">Algorithm identifier requested by the framework.</param>
    /// <param name="args">Framework arguments; <c>args[0]</c> is the security key when present.</param>
    /// <returns>
    /// True for A128CBC-HS256 with a <see cref="SymmetricSecurityKey"/>, and for RSA-OAEP / RS256 / RS384 / RS512
    /// with a <see cref="KeyVaultRsaSecurityKey"/>; otherwise false.
    /// </returns>
    public bool IsSupportedAlgorithm(string algorithm, params object[] args)
    {
        var firstArg = GetFirstArg(args);

        if (algorithm == SecurityAlgorithms.Aes128CbcHmacSha256)
        {
            return firstArg is SymmetricSecurityKey;
        }

        // Only claim RSA algorithms for Key Vault keys. Claiming them for any key would make the
        // factory call Create (which throws for other key types) instead of falling back to its
        // built-in providers.
        return firstArg is KeyVaultRsaSecurityKey
            && (algorithm == SecurityAlgorithms.RsaOAEP || IsRsaSignatureAlgorithm(algorithm));
    }

    /// <summary>Creates the provider instance for <paramref name="algorithm"/>.</summary>
    /// <param name="algorithm">Algorithm identifier requested by the framework.</param>
    /// <param name="args">Framework arguments; <c>args[0]</c> is the security key.</param>
    /// <returns>
    /// An <see cref="AuthenticatedEncryptionProvider"/>, <see cref="KeyVaultKeyWrapProvider"/>
    /// or <see cref="KeyVaultKeySignatureProvider"/>.
    /// </returns>
    /// <exception cref="ArgumentException">The algorithm/argument combination is not supported.</exception>
    public object Create(string algorithm, params object[] args)
    {
        // The framework classes always call IsSupportedAlgorithm first,
        // so algorithm and args are expected to have sensible values here.
        var firstArg = GetFirstArg(args);

        if (algorithm == SecurityAlgorithms.Aes128CbcHmacSha256 && firstArg is SymmetricSecurityKey symmetricKey)
        {
            return new AuthenticatedEncryptionProvider(symmetricKey, algorithm);
        }

        if (firstArg is KeyVaultRsaSecurityKey rsaKey)
        {
            if (algorithm == SecurityAlgorithms.RsaOAEP)
            {
                // args[1] (willUnwrap) is not needed: the provider handles both wrap and unwrap.
                return new KeyVaultKeyWrapProvider(GetCryptographyClient(rsaKey), rsaKey, algorithm);
            }

            if (IsRsaSignatureAlgorithm(algorithm))
            {
                // args[1] (willCreateSignatures) is not needed: the provider signs and verifies.
                return new KeyVaultKeySignatureProvider(GetCryptographyClient(rsaKey), rsaKey, algorithm);
            }
        }

        throw new ArgumentException($"Unsupported algorithm: {algorithm}, or invalid arguments given", nameof(algorithm));
    }

    /// <summary>
    /// Releases an instance previously returned by <see cref="Create"/> by disposing it when it is
    /// <see cref="IDisposable"/>, mirroring what the default factory does for its own providers.
    /// </summary>
    /// <param name="cryptoInstance">Instance to release; may be null.</param>
    public void Release(object cryptoInstance)
    {
        if (cryptoInstance is IDisposable disposable)
        {
            disposable.Dispose();
        }
    }

    // Returns args[0], or null when args is null or empty.
    private static object GetFirstArg(object[] args)
    {
        return args != null && args.Length > 0 ? args[0] : null;
    }

    // True for the RSA PKCS#1 v1.5 signature algorithms supported by KeyVaultKeySignatureProvider.
    private static bool IsRsaSignatureAlgorithm(string algorithm)
    {
        return algorithm == SecurityAlgorithms.RsaSha256
            || algorithm == SecurityAlgorithms.RsaSha384
            || algorithm == SecurityAlgorithms.RsaSha512;
    }

    // Creates a CryptographyClient bound to the exact key name and version of the security key.
    private CryptographyClient GetCryptographyClient(KeyVaultRsaSecurityKey key)
    {
        return _keyClient.GetCryptographyClient(key.KeyName, key.KeyVersion);
    }
}
/// <summary>
/// <see cref="SignatureProvider"/> for RS256/RS384/RS512. Signing is delegated to Azure Key Vault
/// (the private key never leaves the vault); verification is done locally with the public key
/// material already present on the <see cref="KeyVaultRsaSecurityKey"/>.
/// </summary>
public class KeyVaultKeySignatureProvider : SignatureProvider
{
    private readonly CryptographyClient _cryptographyClient;

    /// <param name="cryptographyClient">Client bound to the Key Vault key used for signing.</param>
    /// <param name="key">The Key Vault key; its public part is used for verification.</param>
    /// <param name="algorithm">One of RS256, RS384 or RS512.</param>
    public KeyVaultKeySignatureProvider(CryptographyClient cryptographyClient, KeyVaultRsaSecurityKey key, string algorithm)
        : base(key, algorithm)
    {
        _cryptographyClient = cryptographyClient;
    }

    /// <summary>Signs <paramref name="input"/> through Key Vault's <c>SignData</c>.</summary>
    /// <param name="input">Bytes to sign.</param>
    /// <returns>The signature bytes.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="input"/> is null or empty.</exception>
    public override byte[] Sign(byte[] input)
    {
        if (input == null || input.Length == 0)
        {
            throw new ArgumentNullException(nameof(input));
        }

        var result = _cryptographyClient.SignData(GetKeyVaultAlgorithm(Algorithm), input);
        return result.Signature;
    }

    /// <summary>Verifies <paramref name="signature"/> over all of <paramref name="input"/>.</summary>
    /// <exception cref="ArgumentNullException">Either array is null or empty.</exception>
    public override bool Verify(byte[] input, byte[] signature)
    {
        if (input == null || input.Length == 0)
        {
            throw new ArgumentNullException(nameof(input));
        }

        if (signature == null || signature.Length == 0)
        {
            throw new ArgumentNullException(nameof(signature));
        }

        return VerifyWithPublicKey(input, signature);

        // Key Vault alternative (costs a network round-trip per verification):
        //var result = _cryptographyClient.VerifyData(GetKeyVaultAlgorithm(Algorithm), input, signature);
        //return result.IsValid;
    }

    /// <summary>
    /// Verifies the signature slice <c>signature[signatureOffset..+signatureLength]</c> over the
    /// input slice <c>input[inputOffset..+inputLength]</c>.
    /// </summary>
    /// <exception cref="ArgumentNullException">Either array is null or empty.</exception>
    /// <exception cref="ArgumentOutOfRangeException">An offset is negative or a length is less than 1.</exception>
    /// <exception cref="ArgumentException">A slice extends past the end of its array.</exception>
    public override bool Verify(byte[] input, int inputOffset, int inputLength, byte[] signature, int signatureOffset, int signatureLength)
    {
        if (input == null || input.Length == 0)
        {
            throw new ArgumentNullException(nameof(input));
        }

        if (signature == null || signature.Length == 0)
        {
            throw new ArgumentNullException(nameof(signature));
        }

        if (inputOffset < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(inputOffset), "inputOffset must be greater than or equal to 0");
        }

        if (inputLength < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(inputLength), "inputLength must be greater than 0");
        }

        // Written as a subtraction so a large offset + length cannot overflow int and slip past the check.
        if (inputLength > input.Length - inputOffset)
        {
            throw new ArgumentException("inputOffset + inputLength must not be greater than input array length", nameof(inputLength));
        }

        if (signatureOffset < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(signatureOffset), "signatureOffset must be greater than or equal to 0");
        }

        if (signatureLength < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(signatureLength), "signatureLength must be greater than 0");
        }

        if (signatureLength > signature.Length - signatureOffset)
        {
            throw new ArgumentException("signatureOffset + signatureLength must not be greater than signature array length", nameof(signatureLength));
        }

        // The arrays may be larger than the data of interest (e.g. trailing zeroes), so only the
        // requested slices take part in verification. A Key Vault-based verification would need
        // these slices copied into exact-length arrays and passed to the two-argument overload.
        return VerifyWithPublicKey(
            input.AsSpan(inputOffset, inputLength),
            signature.AsSpan(signatureOffset, signatureLength));
    }

    /// <summary>No unmanaged resources are held; nothing to dispose.</summary>
    protected override void Dispose(bool disposing)
    {
    }

    // Verifies locally with the RSA public key; PKCS#1 v1.5 padding as required by RS256/384/512.
    private bool VerifyWithPublicKey(ReadOnlySpan<byte> data, ReadOnlySpan<byte> signature)
    {
        var key = (KeyVaultRsaSecurityKey)Key;
        using var rsa = key.Key.Key.ToRSA();
        return rsa.VerifyData(data, signature, GetHashAlgorithm(Algorithm), RSASignaturePadding.Pkcs1);
    }

    // Maps a JWA RSA signature algorithm to the .NET hash algorithm it uses.
    private static HashAlgorithmName GetHashAlgorithm(string algorithm)
    {
        return algorithm switch
        {
            SecurityAlgorithms.RsaSha256 => HashAlgorithmName.SHA256,
            SecurityAlgorithms.RsaSha384 => HashAlgorithmName.SHA384,
            SecurityAlgorithms.RsaSha512 => HashAlgorithmName.SHA512,
            _ => throw new NotSupportedException($"Unsupported algorithm: {algorithm}"),
        };
    }

    // Maps a JWA RSA signature algorithm to the Key Vault SDK equivalent.
    private static SignatureAlgorithm GetKeyVaultAlgorithm(string algorithm)
    {
        return algorithm switch
        {
            SecurityAlgorithms.RsaSha256 => SignatureAlgorithm.RS256,
            SecurityAlgorithms.RsaSha384 => SignatureAlgorithm.RS384,
            SecurityAlgorithms.RsaSha512 => SignatureAlgorithm.RS512,
            _ => throw new NotSupportedException($"Unsupported algorithm: {algorithm}"),
        };
    }
}
/// <summary>
/// <see cref="KeyWrapProvider"/> for RSA-OAEP. Wrapping (encryption with the public key) is done
/// locally; unwrapping requires the private key and is delegated to Azure Key Vault.
/// </summary>
public class KeyVaultKeyWrapProvider : KeyWrapProvider
{
    private readonly CryptographyClient _cryptographyClient;

    /// <param name="cryptographyClient">Client bound to the Key Vault key used for unwrapping.</param>
    /// <param name="key">The Key Vault key; its public part is used for wrapping.</param>
    /// <param name="algorithm">Key management algorithm; only RSA-OAEP is mapped.</param>
    public KeyVaultKeyWrapProvider(CryptographyClient cryptographyClient, KeyVaultRsaSecurityKey key, string algorithm)
    {
        _cryptographyClient = cryptographyClient;
        Key = key;
        Algorithm = algorithm;
    }

    /// <inheritdoc/>
    public override SecurityKey Key { get; }

    /// <inheritdoc/>
    public override string Algorithm { get; }

    /// <summary>
    /// Caller-supplied context string. This provider stores it but does not use it; a plain
    /// auto-property avoids throwing from a property accessor.
    /// </summary>
    public override string Context { get; set; }

    /// <summary>Encrypts (wraps) a content-encryption key with the RSA public key.</summary>
    /// <param name="keyBytes">The key to wrap.</param>
    /// <returns>The wrapped key.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="keyBytes"/> is null or empty.</exception>
    public override byte[] WrapKey(byte[] keyBytes)
    {
        if (keyBytes == null || keyBytes.Length == 0)
        {
            throw new ArgumentNullException(nameof(keyBytes));
        }

        // The public key is already available locally, so no Key Vault call is needed.
        var key = (KeyVaultRsaSecurityKey)Key;
        using var rsa = key.Key.Key.ToRSA();
        return rsa.Encrypt(keyBytes, GetRsaEncryptionPadding(Algorithm));

        // Key Vault alternative:
        //var result = _cryptographyClient.WrapKey(GetKeyVaultAlgorithm(Algorithm), keyBytes);
        //return result.EncryptedKey;
    }

    /// <summary>Decrypts (unwraps) a wrapped key through Key Vault's <c>UnwrapKey</c>.</summary>
    /// <param name="keyBytes">The wrapped key.</param>
    /// <returns>The unwrapped key bytes.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="keyBytes"/> is null or empty.</exception>
    public override byte[] UnwrapKey(byte[] keyBytes)
    {
        if (keyBytes == null || keyBytes.Length == 0)
        {
            throw new ArgumentNullException(nameof(keyBytes));
        }

        var result = _cryptographyClient.UnwrapKey(GetKeyVaultAlgorithm(Algorithm), keyBytes);
        return result.Key;
    }

    /// <summary>No unmanaged resources are held; nothing to dispose.</summary>
    protected override void Dispose(bool disposing)
    {
    }

    // JWA "RSA-OAEP" is OAEP with SHA-1 / MGF1-SHA-1.
    private static RSAEncryptionPadding GetRsaEncryptionPadding(string algorithm)
    {
        return algorithm switch
        {
            SecurityAlgorithms.RsaOAEP => RSAEncryptionPadding.OaepSHA1,
            _ => throw new NotSupportedException($"Unsupported algorithm: {algorithm}"),
        };
    }

    // Maps a JWA key management algorithm to the Key Vault SDK equivalent.
    private static KeyWrapAlgorithm GetKeyVaultAlgorithm(string algorithm)
    {
        return algorithm switch
        {
            SecurityAlgorithms.RsaOAEP => KeyWrapAlgorithm.RsaOaep,
            _ => throw new NotSupportedException($"Unsupported algorithm: {algorithm}"),
        };
    }
}
/// <summary>
/// <see cref="AsymmetricSecurityKey"/> wrapping an Azure Key Vault RSA key. The key id is the
/// base64url-encoded JWK thumbprint of the public key.
/// </summary>
public class KeyVaultRsaSecurityKey : AsymmetricSecurityKey
{
    /// <param name="key">The Key Vault key; expected to be an RSA key.</param>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is null.</exception>
    public KeyVaultRsaSecurityKey(KeyVaultKey key)
    {
        Key = key ?? throw new ArgumentNullException(nameof(key));
        KeyId = GetKeyId(key);
    }

    /// <summary>The underlying Key Vault key.</summary>
    public KeyVaultKey Key { get; }

    /// <summary>
    /// Modulus size in bits. Leading zero octets are ignored because some encoders prefix one to
    /// the modulus (e.g. 257 octets for a 2048-bit key), which would otherwise overstate the size.
    /// </summary>
    public override int KeySize
    {
        get
        {
            byte[] modulus = Key.Key.N;
            int firstSignificant = 0;
            while (firstSignificant < modulus.Length && modulus[firstSignificant] == 0)
            {
                firstSignificant++;
            }

            return (modulus.Length - firstSignificant) * 8;
        }
    }

    /// <inheritdoc/>
    public override string KeyId { get; set; }

    /// <summary>Key Vault key name.</summary>
    public string KeyName => Key.Properties.Name;

    /// <summary>Key Vault key version.</summary>
    public string KeyVersion => Key.Properties.Version;

    // In our case we always have a private key in Key Vault
    // (could check supported key operations).
    [Obsolete]
    public override bool HasPrivateKey => true;

    /// <inheritdoc/>
    public override PrivateKeyStatus PrivateKeyStatus => PrivateKeyStatus.Exists;

    // Computes the RFC 7638 JWK thumbprint of the public key, base64url-encoded.
    private static string GetKeyId(KeyVaultKey key)
    {
        using var rsa = key.Key.ToRSA();
        var rsaKey = new RsaSecurityKey(rsa);
        var thumbprint = rsaKey.ComputeJwkThumbprint();
        return Base64UrlEncoder.Encode(thumbprint);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment