Skip to content

Instantly share code, notes, and snippets.

@alexanderankin
Last active October 23, 2023 14:55
Show Gist options
  • Save alexanderankin/b34a2a2ff12bb49404096c0e64668aa3 to your computer and use it in GitHub Desktop.
Save alexanderankin/b34a2a2ff12bb49404096c0e64668aa3 to your computer and use it in GitHub Desktop.

Client Assertion

Explains how to authenticate to Azure AD with the OAuth 2.0 "client credentials" grant using a client assertion (a JWT signed with a certificate's private key) instead of a client secret.

pip install python-jose jwt cryptography msal   (msal is only needed for the `msal` method)

references:

creating a cert:

  • openssl req -days 365 -new -newkey rsa:2048 -sha256 -nodes -x509 -keyout demo.key -out demo.crt -subj "/CN=foo"
  • openssl req -days 365 -new -newkey rsa:4096 -sha256 -nodes -x509 -keyout $name.key -out $name.crt -subj "/CN=my-app/O=my-org"

appending cert:

  • az ad sp create-for-rbac --role Reader --scopes /subscriptions/$subscription --name demo
  • az ad sp credential reset --id $appId --cert @demo.crt --append (without --append, the existing credentials are replaced)
  • az ad sp credential list --id $appId --cert

TODO: remove the remaining client-secret credentials from the service principal once certificate auth works.

import argparse
import dataclasses
import json
import pathlib
import sys
import typing
import urllib.error
import urllib.parse
import urllib.request
# Short aliases accepted by --tenant in main(): a matching key is replaced by its
# tenant id; any other value is used verbatim as the tenant id.
KNOWN_TENANTS = {
    'personal': '50e38f06-4334-4d1d-8d3b-346dd52186af',
}
# noinspection DuplicatedCode
IgnorePropertiesType = typing.TypeVar('IgnorePropertiesType')


def ignore_properties(cls: typing.Type[IgnorePropertiesType],
                      dict_: typing.Mapping[str, typing.Any]) -> IgnorePropertiesType:
    """Instantiate dataclass ``cls`` from a mapping, dropping keys it has no field for.

    Lets a dataclass be built from e.g. ``vars(argparse.Namespace)`` that also
    carries CLI-only options.

    :param cls: a dataclass type (``dataclasses.fields`` raises ``TypeError`` otherwise)
    :param dict_: source mapping; only keys equal to a field name of ``cls`` are passed on
    :return: ``cls(**filtered)``
    """
    # set for O(1) membership tests while filtering
    field_names = {f.name for f in dataclasses.fields(cls)}
    filtered = {k: v for k, v in dict_.items() if k in field_names}
    return cls(**filtered)
def fetch(req: urllib.request.Request) -> str:
    """Send ``req`` and return the response body decoded as UTF-8.

    An HTTP error status is re-raised as a plain ``Exception`` carrying the
    status code and the raw error body, chained to the original ``HTTPError``.
    """
    try:
        with urllib.request.urlopen(req) as resp:
            raw_body = resp.read()
    except urllib.error.HTTPError as http_err:
        raise Exception(f'http error: {http_err.code} - {http_err.read()}') from http_err
    return raw_body.decode('utf-8')
@dataclasses.dataclass(frozen=True)
class Client:
    """Immutable identity of the app registration that requests tokens.

    Fields:
      tenant_id -- directory (tenant) id placed in the login.microsoftonline.com URLs
      client_id -- application (client) id
      scopes    -- OAuth scopes requested from the v2.0 endpoint; ``None`` when not given
    """
    tenant_id: str
    client_id: str
    # Optional because the default is None.
    scopes: typing.Optional[typing.List[str]] = None
def client_secret_method(client: Client, secret: str,
                         resource: str = 'https://cognitiveservices.azure.com'):
    """Get an access token with the client-credentials grant and a client secret.

    Uses the v1 token endpoint (``/oauth2/token``), which takes a ``resource``
    instead of scopes, so ``client.scopes`` is not used here.

    :param client: tenant and client id to authenticate as
    :param secret: the client secret value
    :param resource: resource the token is requested for (v1 ``resource`` form field)
    :return: the ``access_token`` string from the JSON response
    :raises ValueError: if ``secret`` is empty or ``None``
    :raises Exception: on an HTTP error status (raised by ``fetch``)
    """
    if not secret:
        # urlencode() would otherwise send the literal text "None" as the secret.
        raise ValueError('a non-empty client secret is required')
    url = f'https://login.microsoftonline.com/{client.tenant_id}/oauth2/token'
    form = {
        'grant_type': 'client_credentials',
        'client_id': client.client_id,
        'client_secret': secret,
        'resource': resource,
    }
    request = urllib.request.Request(
        url,
        data=urllib.parse.urlencode(form).encode('utf-8'),
        headers={'content-type': 'application/x-www-form-urlencoded'},
        method='POST',
    )
    return json.loads(fetch(request))['access_token']
def _make_client_assertion(certificate_key_path: pathlib.Path, client: Client, thumb_print: str):
    """Build and RS256-sign a JWT client assertion for the v2.0 token endpoint.

    :param certificate_key_path: PEM file with the certificate's private key
    :param client: supplies the tenant (for ``aud``) and the client id (``iss``/``sub``)
    :param thumb_print: hex SHA-1 thumbprint of the certificate; ``:`` separators
        (as printed by ``openssl x509 -fingerprint``) are tolerated
    :return: the compact-serialized signed JWT
    """
    import base64
    import datetime
    import uuid

    import jose.constants
    import jose.jwt

    # Read the clock once so iat/nbf/exp agree; NumericDate in whole seconds.
    now = int(datetime.datetime.now(tz=datetime.timezone.utc).timestamp())
    jwt_payload = {
        'aud': f'https://login.microsoftonline.com/{client.tenant_id}/oauth2/v2.0/token',
        'jti': str(uuid.uuid4()),
        'iss': client.client_id,
        'sub': client.client_id,
        'exp': now + 60 * 60,  # valid for one hour
        'iat': now,
        'nbf': now,
    }
    thumb_bytes = bytes.fromhex(thumb_print.replace(':', ''))
    # JWS header values use base64url WITHOUT '=' padding (RFC 7515).
    x5t = base64.urlsafe_b64encode(thumb_bytes).decode().rstrip('=')
    return jose.jwt.encode(claims=jwt_payload,
                           key=certificate_key_path.read_text(),
                           algorithm=jose.constants.ALGORITHMS.RS256,
                           headers={'x5t': x5t})
def _use_client_assertion(client: Client, assertion: str):
    """Redeem a signed client assertion for an access token (client-credentials grant, v2.0).

    :param client: tenant, client id and the scopes to request
    :param assertion: the JWT built by ``_make_client_assertion``
    :return: the ``access_token`` string from the JSON response
    :raises ValueError: if ``client.scopes`` is empty or ``None`` (v2.0 needs ``scope``)
    :raises Exception: on an HTTP error status (raised by ``fetch``)
    """
    if not client.scopes:
        raise ValueError('at least one scope is required for the v2.0 token endpoint')
    form = {
        'grant_type': 'client_credentials',
        'client_id': client.client_id,
        'scope': ' '.join(client.scopes),
        'client_assertion_type': 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
        'client_assertion': assertion,
    }
    request = urllib.request.Request(
        url=f'https://login.microsoftonline.com/{client.tenant_id}/oauth2/v2.0/token',
        data=urllib.parse.urlencode(form).encode('utf-8'),
        headers={'content-type': 'application/x-www-form-urlencoded'},
        method='POST',
    )
    return json.loads(fetch(request))['access_token']
def client_certificate_method(client: Client,
                              certificate_key_path: pathlib.Path,
                              certificate_path: pathlib.Path):
    """Authenticate with a certificate, deriving its thumbprint from the PEM certificate file.

    The thumbprint is the hex SHA-1 digest of the certificate's DER encoding;
    the token request itself is delegated to ``client_thumbprint_method``.
    """
    pem_bytes = certificate_path.read_bytes()

    import cryptography.hazmat.backends
    import cryptography.hazmat.primitives.serialization
    import cryptography.x509

    certificate = cryptography.x509.load_pem_x509_certificate(
        pem_bytes, cryptography.hazmat.backends.default_backend())
    der_encoding = cryptography.hazmat.primitives.serialization.Encoding.DER
    der_bytes = certificate.public_bytes(encoding=der_encoding)

    import hashlib
    thumbprint_hex = hashlib.sha1(der_bytes).hexdigest()
    return client_thumbprint_method(client, certificate_key_path, thumbprint_hex)
def client_thumbprint_method(client: Client,
                             certificate_key_path: pathlib.Path,
                             thumbprint: str):
    """Authenticate with a private key plus a known certificate thumbprint.

    Signs a client assertion with the key, then exchanges it for an access token.
    """
    return _use_client_assertion(
        client,
        _make_client_assertion(certificate_key_path, client, thumbprint),
    )
def msal_method(client: Client,
                certificate_key_path: pathlib.Path,
                thumbprint: str):
    """Authenticate with private key + thumbprint, letting MSAL build and redeem the assertion.

    :param client: tenant, client id and scopes to request
    :param certificate_key_path: PEM file with the certificate's private key
    :param thumbprint: hex SHA-1 thumbprint of the certificate
    :return: the ``access_token`` string
    :raises Exception: when MSAL's result has no ``access_token``; the message carries
        MSAL's ``error`` and ``error_description`` fields
    """
    import msal
    app = msal.ConfidentialClientApplication(
        client.client_id,
        authority=f'https://login.microsoftonline.com/{client.tenant_id}',
        client_credential={
            'thumbprint': thumbprint,
            'private_key': certificate_key_path.read_text(),
        },
    )
    result = app.acquire_token_for_client(scopes=client.scopes)
    # On failure MSAL returns an error dict instead of raising.
    if 'access_token' not in result:
        raise Exception(f"msal error: {result.get('error')} - {result.get('error_description')}")
    return result['access_token']
def main():
    """Command-line entry point: fetch an access token with the chosen auth method and print it.

    Global options select the tenant, client and scopes. The sub-command selects
    the credential type: ``secret``, ``cert``, ``thumbprint`` or ``msal``. The
    token goes to stdout and a ``result:`` label goes to stderr.
    """
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--dry-run', action='store_true', help='print args + exit')
    parser.add_argument('-t', '--tenant', dest='tenant_id', help='tenant id', required=True)
    parser.add_argument('-c', '--client', dest='client_id', help='client id', required=True)
    parser.add_argument('--scope', type=str, dest='scopes',
                        default='https://cognitiveservices.azure.com/.default',
                        help='OAuth scope (space separated list) - default openai (todo revisit)')
    sp = parser.add_subparsers(dest='method', required=True, help='Auth method')

    secret_parser = sp.add_parser('secret',
                                  help='use client_id and client_secret in client_credentials grant')
    secret_parser.add_argument('-cs:e', '--client-secret-env', help='the env var with the client secret')
    secret_parser.add_argument('-cs', '--client-secret', help='the value of the client secret')

    cert_parser = sp.add_parser('cert', help='client cert (cert.pem + cert.key)')
    cert_parser.add_argument('-ck', '--cert-key', help='path to certificate key file')
    cert_parser.add_argument('-ck:e', '--cert-key-env', help='env var with path to certificate key file')
    cert_parser.add_argument('-c', '--cert', help='path to public certificate file')
    cert_parser.add_argument('-c:e', '--cert-env', help='env var with path to public certificate file')

    thumbprint_parser = sp.add_parser('thumbprint', help='client cert (cert.key ONLY)')
    msal_parser = sp.add_parser('msal', help='client cert (cert.key ONLY)')
    for p in [thumbprint_parser, msal_parser]:
        p.add_argument('-ck', '--cert-key', help='path to certificate key file')
        p.add_argument('-ck:e', '--cert-key-env', help='env var with path to certificate key file')
        p.add_argument('-t', '--thumbprint', help='thumbprint (in lieu of public key)')

    args = parser.parse_args()
    # vars() returns the Namespace's own __dict__, so these edits are visible on `args` too.
    args_dict = vars(args)
    args_dict['scopes'] = (args_dict.get('scopes') or '').split()
    args_dict['tenant_id'] = KNOWN_TENANTS.get(args_dict['tenant_id'], args_dict['tenant_id'])
    client = ignore_properties(Client, args_dict)

    if args.dry_run:
        import pprint
        pprint.pprint(args)
        pprint.pprint(args_dict)
        pprint.pprint(client)
        sys.exit(0)

    def or_read_env(value, env_var, what):
        """Return ``value`` if set, else the content of env var ``env_var``; exit with a usage error if neither."""
        if value:
            return value
        # os.getenv(None) raises TypeError, so only consult the environment when a name was given.
        from_env = os.getenv(env_var) if env_var else None
        if from_env:
            return from_env
        parser.error(f'{what} is required (pass it directly or via its environment-variable option)')

    def ensure_exists(path: str) -> pathlib.Path:
        """Return ``path`` as a Path, raising if nothing exists there."""
        p = pathlib.Path(path)
        if p.exists():
            return p
        raise Exception(f'interpreting value as path but it does not exist: {p}')

    def require_thumbprint():
        """Return --thumbprint, exiting with a usage error when it is missing."""
        if not args.thumbprint:
            parser.error('--thumbprint is required for this method')
        return args.thumbprint

    def cert_key_path():
        return ensure_exists(or_read_env(args.cert_key, args.cert_key_env, 'certificate key path'))

    methods = {
        'secret': lambda: client_secret_method(
            client, or_read_env(args.client_secret, args.client_secret_env, 'client secret')),
        'cert': lambda: client_certificate_method(
            client,
            cert_key_path(),
            ensure_exists(or_read_env(args.cert, args.cert_env, 'certificate path'))),
        'thumbprint': lambda: client_thumbprint_method(client, cert_key_path(), require_thumbprint()),
        'msal': lambda: msal_method(client, cert_key_path(), require_thumbprint()),
    }
    result = methods[args.method]()
    print('result:', file=sys.stderr)
    print(result)
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not on import.
    main()
import argparse
import base64
import dataclasses
import datetime
import pathlib
import urllib.error
import urllib.parse
import urllib.request
import uuid
import jose.jwt
from jose.constants import ALGORITHMS
import thumbprint_util
@dataclasses.dataclass
class Credentials:
    """Inputs for building and redeeming a certificate client assertion.

    Every field defaults to ``None`` so the class can be filled straight from
    ``vars(argparse.Namespace)``.
    NOTE: annotated ``str`` although ``None`` is the default.
    """
    # Directory (tenant) id.
    tenant_id: str = dataclasses.field(default=None)
    # Application (client) id.
    client_id: str = dataclasses.field(default=None)
    # Path to the PEM certificate file.
    certificate: str = dataclasses.field(default=None)
    # Path to the PEM private-key file.
    certificate_key: str = dataclasses.field(default=None)
    # Accepted from the CLI; not read by create_jwt_token.
    certificate_password: str = dataclasses.field(default=None)
def main():
    """Command-line entry point: sign a client assertion with the certificate key and redeem it.

    Prints the raw token-endpoint response body.
    """
    parser = argparse.ArgumentParser()
    # These four are all dereferenced unconditionally later, so require them up front
    # instead of failing with Path(None) / a "None" tenant in the URL.
    parser.add_argument('-t', '--tenant-id', required=True)
    parser.add_argument('-u', '--client-id', required=True)
    parser.add_argument('-c', '--certificate', required=True, help='path to file')
    parser.add_argument('-k', '--certificate-key', required=True, help='path to file')
    # NOTE: currently unused; create_jwt_token reads the key as plain PEM text.
    parser.add_argument('--certificate-password', required=False)
    args = parser.parse_args()
    credentials = Credentials(**vars(args))
    token = create_jwt_token(credentials)
    result = submit_jwt_token(token, credentials)
    print(result)
def create_jwt_token(credentials: Credentials):
    """Build and RS256-sign a client-assertion JWT for the v2.0 token endpoint (python-jose).

    :param credentials: uses ``tenant_id`` (for ``aud``), ``client_id`` (``iss``/``sub``),
        ``certificate`` (PEM certificate path, for the ``x5t`` header) and
        ``certificate_key`` (PEM private-key path); ``certificate_password`` is not used
    :return: the compact-serialized signed JWT
    """
    tenant = credentials.tenant_id
    client_id = credentials.client_id
    private_key = pathlib.Path(credentials.certificate_key).read_text()

    # One clock reading so iat/nbf/exp agree; NumericDate in whole seconds.
    now = int(datetime.datetime.now(tz=datetime.timezone.utc).timestamp())
    payload = {
        "aud": f"https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token",
        "exp": now + 60 * 60,  # valid for one hour
        "iat": now,
        "jti": str(uuid.uuid4()),
        "iss": client_id,
        "sub": client_id,
        "nbf": now,
    }

    thumbprint = thumbprint_util.thumbprint(pathlib.Path(credentials.certificate).read_text())
    # x5t = base64url(SHA-1 of the DER certificate) without '=' padding (RFC 7515).
    x5t = base64.urlsafe_b64encode(bytes.fromhex(thumbprint)).decode().rstrip('=')
    return jose.jwt.encode(claims=payload, key=private_key, algorithm=ALGORITHMS.RS256,
                           headers={"x5t": x5t})
def submit_jwt_token(token, credentials):
    """Redeem a client-assertion JWT for a Microsoft Graph token (client-credentials grant).

    :param token: the signed JWT from ``create_jwt_token``
    :param credentials: supplies ``tenant_id`` and ``client_id``
    :return: the raw response body as bytes
    :raises Exception: on an HTTP error status, with the code and error body
    """
    url = f"https://login.microsoftonline.com/{credentials.tenant_id}/oauth2/v2.0/token"
    body = urllib.parse.urlencode({
        'scope': 'https://graph.microsoft.com/.default',
        'client_id': credentials.client_id,
        'client_assertion_type': 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
        'client_assertion': token,
        'grant_type': 'client_credentials'
    })
    # Build the POST explicitly (the original created an unused Request to "https://").
    request = urllib.request.Request(url, data=body.encode('utf-8'), method='POST')
    try:
        with urllib.request.urlopen(request) as response:
            return response.read()
    except urllib.error.HTTPError as e:
        raise Exception(f'HTTP Error: {e.code}: {e.read()}') from e
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not on import.
    main()
import argparse
import base64
import dataclasses
import datetime
import pathlib
import urllib.error
import urllib.parse
import urllib.request
import uuid
import jwt
import thumbprint_util
@dataclasses.dataclass
class Credentials:
    """Inputs for building and redeeming a certificate client assertion.

    Every field defaults to ``None`` so the class can be filled straight from
    ``vars(argparse.Namespace)``.
    NOTE: annotated ``str`` although ``None`` is the default.
    """
    # Directory (tenant) id.
    tenant_id: str = dataclasses.field(default=None)
    # Application (client) id.
    client_id: str = dataclasses.field(default=None)
    # Path to the PEM certificate file.
    certificate: str = dataclasses.field(default=None)
    # Path to the PEM private-key file.
    certificate_key: str = dataclasses.field(default=None)
    # Accepted from the CLI; not read by create_jwt_token.
    certificate_password: str = dataclasses.field(default=None)
def main():
    """Command-line entry point: sign a client assertion with the certificate key and redeem it.

    Prints the raw token-endpoint response body.
    """
    parser = argparse.ArgumentParser()
    # These four are all dereferenced unconditionally later, so require them up front
    # instead of failing with Path(None) / a "None" tenant in the URL.
    parser.add_argument('-t', '--tenant-id', required=True)
    parser.add_argument('-u', '--client-id', required=True)
    parser.add_argument('-c', '--certificate', required=True, help='path to file')
    parser.add_argument('-k', '--certificate-key', required=True, help='path to file')
    # NOTE: currently unused; create_jwt_token reads the key as plain PEM text.
    parser.add_argument('--certificate-password', required=False)
    args = parser.parse_args()
    credentials = Credentials(**vars(args))
    token = create_jwt_token(credentials)
    result = submit_jwt_token(token, credentials)
    print(result)
def create_jwt_token(credentials: Credentials):
    """Build and RS256-sign a client-assertion JWT for the v2.0 token endpoint (``jwt`` package).

    :param credentials: uses ``tenant_id`` (for ``aud``), ``client_id`` (``iss``/``sub``),
        ``certificate`` (PEM certificate path, for the ``x5t`` header) and
        ``certificate_key`` (PEM private-key path); ``certificate_password`` is not used
    :return: the compact-serialized signed JWT
    """
    tenant = credentials.tenant_id
    client_id = credentials.client_id
    private_key = pathlib.Path(credentials.certificate_key).read_text()

    # One clock reading so iat/nbf/exp agree; NumericDate in whole seconds.
    now = int(datetime.datetime.now(tz=datetime.timezone.utc).timestamp())
    payload = {
        "aud": f"https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token",
        "exp": now + 60 * 60,  # valid for one hour
        "iat": now,
        "jti": str(uuid.uuid4()),
        "iss": client_id,
        "sub": client_id,
        "nbf": now,
    }

    thumbprint = thumbprint_util.thumbprint(pathlib.Path(credentials.certificate).read_text())
    # x5t = base64url(SHA-1 of the DER certificate) without '=' padding (RFC 7515).
    x5t = base64.urlsafe_b64encode(bytes.fromhex(thumbprint)).decode().rstrip('=')
    key = jwt.jwk_from_pem(private_key.encode())
    return jwt.JWT().encode(payload, key, alg="RS256", optional_headers={"x5t": x5t})
def submit_jwt_token(token, credentials):
    """Redeem a client-assertion JWT for a Microsoft Graph token (client-credentials grant).

    :param token: the signed JWT from ``create_jwt_token``
    :param credentials: supplies ``tenant_id`` and ``client_id``
    :return: the raw response body as bytes
    :raises Exception: on an HTTP error status, with the code and error body
    """
    url = f"https://login.microsoftonline.com/{credentials.tenant_id}/oauth2/v2.0/token"
    body = urllib.parse.urlencode({
        'scope': 'https://graph.microsoft.com/.default',
        'client_id': credentials.client_id,
        'client_assertion_type': 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
        'client_assertion': token,
        'grant_type': 'client_credentials'
    })
    # Build the POST explicitly (the original created an unused Request to "https://").
    request = urllib.request.Request(url, data=body.encode('utf-8'), method='POST')
    try:
        with urllib.request.urlopen(request) as response:
            return response.read()
    except urllib.error.HTTPError as e:
        raise Exception(f'HTTP Error: {e.code}: {e.read()}') from e
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not on import.
    main()
import hashlib
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.x509 import load_pem_x509_certificate
def thumbprint(cert_data: str) -> str:
    """Return the SHA-1 thumbprint of a PEM-encoded X.509 certificate as lowercase hex.

    The callers base64url-encode this value into the JWT ``x5t`` header.

    :param cert_data: the PEM certificate text (``-----BEGIN CERTIFICATE-----`` ...);
        already-encoded ``bytes`` are accepted as well
    :return: 40-character lowercase hex SHA-1 digest of the certificate's DER encoding
    """
    pem = cert_data if isinstance(cert_data, bytes) else cert_data.encode()
    # Hash the DER form, not the PEM text, so whitespace/line endings don't matter.
    der = load_pem_x509_certificate(pem, default_backend()).public_bytes(encoding=Encoding.DER)
    return hashlib.sha1(der).hexdigest()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment