|
import argparse |
|
import dataclasses |
|
import json |
|
import pathlib |
|
import sys |
|
import typing |
|
import urllib.error |
|
import urllib.parse |
|
import urllib.request |
|
|
|
# Short aliases accepted by --tenant; main() swaps an alias for the tenant id that
# is substituted into the login.microsoftonline.com URLs. Unknown values pass through.
KNOWN_TENANTS = dict(
    personal='50e38f06-4334-4d1d-8d3b-346dd52186af',
)
|
|
|
# noinspection DuplicatedCode
IgnorePropertiesType = typing.TypeVar('IgnorePropertiesType')


def ignore_properties(cls: typing.Type[IgnorePropertiesType],
                      dict_: typing.Mapping[str, typing.Any]) -> IgnorePropertiesType:
    """Instantiate dataclass *cls* from *dict_*, silently dropping keys it cannot accept.

    Only keys that name a field accepted by the generated ``cls.__init__`` are passed
    on. Every other key is ignored, including fields declared with ``init=False``,
    which the constructor would reject.

    :param cls: a dataclass type.
    :param dict_: mapping of candidate constructor keyword arguments.
    :returns: a new ``cls`` instance.
    :raises TypeError: if *cls* is not a dataclass, or a required field is missing
        from *dict_*.
    """
    # A set gives O(1) membership tests; init=False fields are excluded because
    # passing them to the constructor raises TypeError.
    accepted = {f.name for f in dataclasses.fields(cls) if f.init}
    return cls(**{k: v for k, v in dict_.items() if k in accepted})
|
|
|
|
|
def fetch(req: urllib.request.Request) -> str:
    """Perform *req* and return the response body as text.

    The body is decoded with the charset declared in the response's Content-Type
    header, falling back to UTF-8 when none is declared or the declared charset is
    not a known codec.

    :param req: the request to send.
    :returns: the decoded response body.
    :raises Exception: on an HTTP error status; the message carries the status code
        and the (decoded) error body.
    """
    try:
        with urllib.request.urlopen(req) as response:
            body = response.read()
            charset = response.headers.get_content_charset() or 'utf-8'
    except urllib.error.HTTPError as e:
        # Decode the error body so the message shows text, not a bytes repr.
        detail = e.read().decode('utf-8', errors='replace')
        raise Exception(f'http error: {e.code} - {detail}') from e

    try:
        return body.decode(charset)
    except LookupError:
        # Unknown/bogus charset name in the header: fall back to the previous default.
        return body.decode('utf-8')
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Client:
    """Identity of the application requesting a token, plus the scopes it asks for.

    Immutable (``frozen=True``).
    """

    # Tenant id (or alias-resolved id) substituted into the login.microsoftonline.com URLs.
    tenant_id: str
    # Client id sent to the token endpoint.
    client_id: str
    # Requested OAuth scopes; None when not set. Annotated Optional because None is
    # the default (PEP 484 no longer treats a None default as implicitly Optional).
    scopes: typing.Optional[typing.List[str]] = None
|
|
|
|
|
def client_secret_method(client: Client, secret: str) -> str:
    """Acquire an access token via the OAuth2 client-credentials grant and a client secret.

    Posts to the v2.0 token endpoint (the endpoint the certificate-based methods use)
    and requests ``client.scopes``. When no scopes are set it falls back to the
    Cognitive Services ``/.default`` scope, which corresponds to the
    ``https://cognitiveservices.azure.com`` resource this function used to request
    from the v1 endpoint.

    :param client: tenant id, client id and requested scopes.
    :param secret: the client secret value.
    :returns: the ``access_token`` field of the JSON token response.
    :raises ValueError: if *secret* is empty or None.
    :raises Exception: from ``fetch`` on an HTTP error response.
    """
    if not secret:
        # urlencode would otherwise send the literal string 'None' as the secret.
        raise ValueError('a client secret is required')

    scopes = client.scopes or ['https://cognitiveservices.azure.com/.default']
    form = {
        'grant_type': 'client_credentials',
        'client_id': client.client_id,
        'client_secret': secret,
        'scope': ' '.join(scopes),
    }
    request = urllib.request.Request(
        f'https://login.microsoftonline.com/{client.tenant_id}/oauth2/v2.0/token',
        data=urllib.parse.urlencode(form).encode('utf-8'),
        headers={'content-type': 'application/x-www-form-urlencoded'},
        method='POST',
    )
    token_response = json.loads(fetch(request))
    return token_response['access_token']
|
|
|
|
|
def _make_client_assertion(certificate_key_path: pathlib.Path, client: Client, thumb_print: str) -> str:
    """Build an RS256-signed JWT client assertion for the certificate-credential flow.

    :param certificate_key_path: PEM file with the private key used to sign the JWT.
    :param client: supplies the tenant (for ``aud``) and client id (``iss``/``sub``).
    :param thumb_print: hex SHA-1 thumbprint of the certificate; whitespace and
        ``:`` separators are tolerated.
    :returns: the encoded JWT string, valid for one hour from now.
    :raises ValueError: if *thumb_print* is empty/None or not valid hex.
    """
    import base64
    import time
    import uuid

    import jose.constants
    import jose.jwt

    if not thumb_print:
        raise ValueError('a certificate thumbprint is required to build the client assertion')

    # bytes.fromhex skips whitespace but not ':' (as in "AB:CD:..." fingerprints).
    thumb_bytes = bytes.fromhex(thumb_print.replace(':', ''))
    # x5t header: base64url of the raw thumbprint bytes (padding kept, as before).
    jwt_thumb_header = {'x5t': base64.urlsafe_b64encode(thumb_bytes).decode()}

    # JWT NumericDate claims as whole seconds since the epoch. time.time() is
    # timezone-independent, unlike arithmetic on naive local datetimes.
    now_ts = int(time.time())
    jwt_payload = {
        'aud': f'https://login.microsoftonline.com/{client.tenant_id}/oauth2/v2.0/token',
        'jti': str(uuid.uuid4()),
        'iss': client.client_id,
        'sub': client.client_id,
        'exp': now_ts + 60 * 60,
        'iat': now_ts,
        'nbf': now_ts,
    }

    return jose.jwt.encode(claims=jwt_payload,
                           key=certificate_key_path.read_text(),
                           algorithm=jose.constants.ALGORITHMS.RS256,
                           headers=jwt_thumb_header)
|
|
|
|
|
def _use_client_assertion(client: Client, assertion: str) -> str:
    """Exchange a signed client assertion for an access token (v2.0 client-credentials grant).

    :param client: tenant id, client id and requested scopes.
    :param assertion: the JWT produced for this client.
    :returns: the ``access_token`` field of the JSON token response.
    :raises Exception: from ``fetch`` on an HTTP error response.
    """
    form = {
        'grant_type': 'client_credentials',
        # Listed as required for the certificate-credential request.
        'client_id': client.client_id,
        # scopes may be None (the Client default); avoid a TypeError in join().
        'scope': ' '.join(client.scopes or []),
        'client_assertion_type': 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
        'client_assertion': assertion,
    }
    request = urllib.request.Request(
        url=f'https://login.microsoftonline.com/{client.tenant_id}/oauth2/v2.0/token',
        data=urllib.parse.urlencode(form).encode('utf-8'),
        headers={'content-type': 'application/x-www-form-urlencoded'},
        method='POST',
    )
    return json.loads(fetch(request))['access_token']
|
|
|
|
|
def client_certificate_method(client: Client,
                              certificate_key_path: pathlib.Path,
                              certificate_path: pathlib.Path):
    """Acquire a token with a certificate key, deriving the thumbprint from the public cert.

    The PEM certificate at *certificate_path* is re-encoded as DER. The hex SHA-1
    digest of those DER bytes is used as the thumbprint, and token acquisition is
    delegated to ``client_thumbprint_method``.
    """
    pem_bytes = certificate_path.read_bytes()

    import cryptography.hazmat.backends
    import cryptography.hazmat.primitives.serialization
    import cryptography.x509
    import hashlib

    backend = cryptography.hazmat.backends.default_backend()
    certificate = cryptography.x509.load_pem_x509_certificate(pem_bytes, backend)
    der_format = cryptography.hazmat.primitives.serialization.Encoding.DER
    der_bytes = certificate.public_bytes(encoding=der_format)

    thumbprint_hex = hashlib.sha1(der_bytes).hexdigest()
    return client_thumbprint_method(client, certificate_key_path, thumbprint_hex)
|
|
|
|
|
def client_thumbprint_method(client: Client,
                             certificate_key_path: pathlib.Path,
                             thumbprint: str):
    """Acquire a token from a private key plus an already-known certificate thumbprint.

    Builds a signed client assertion with the key at *certificate_key_path* and
    exchanges it for an access token, which is returned.
    """
    return _use_client_assertion(
        client,
        _make_client_assertion(certificate_key_path, client, thumbprint),
    )
|
|
|
|
|
def msal_method(client: Client,
                certificate_key_path: pathlib.Path,
                thumbprint: str):
    """Acquire a token through MSAL's confidential client with a certificate credential.

    :param client: tenant id (authority), client id and requested scopes.
    :param certificate_key_path: PEM private key file.
    :param thumbprint: hex SHA-1 thumbprint of the matching certificate.
    :returns: the access token string.
    :raises Exception: if MSAL's response has no ``access_token``; the message
        includes its ``error`` / ``error_description`` fields.
    """
    import msal

    app = msal.ConfidentialClientApplication(
        client.client_id,
        authority=f'https://login.microsoftonline.com/{client.tenant_id}',
        client_credential={
            'thumbprint': thumbprint,
            'private_key': certificate_key_path.read_text(),
        },
    )

    token_response = app.acquire_token_for_client(scopes=client.scopes)
    if 'access_token' not in token_response:
        # MSAL reports failures in the returned dict instead of raising; surface them
        # rather than failing with a bare KeyError.
        raise Exception(f"msal error: {token_response.get('error')} - "
                        f"{token_response.get('error_description')}")
    return token_response['access_token']
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Create the CLI parser: global client options plus one sub-command per auth method."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--dry-run', action='store_true', help='print args + exit')
    parser.add_argument('-t', '--tenant', dest='tenant_id', help='tenant id', required=True)
    parser.add_argument('-c', '--client', dest='client_id', help='client id', required=True)
    parser.add_argument('--scope', type=str, dest='scopes',
                        default='https://cognitiveservices.azure.com/.default',
                        help='OAuth scope (space separated list) - default openai (todo revisit)')

    sp = parser.add_subparsers(dest='method', required=True, help='Auth method')

    secret_parser = sp.add_parser('secret',
                                  help='use client_id and client_secret in client_credentials grant')
    secret_parser.add_argument('-cs:e', '--client-secret-env', help='the env var with the client secret')
    secret_parser.add_argument('-cs', '--client-secret', help='the value of the client secret')

    cert_parser = sp.add_parser('cert', help='client cert (cert.pem + cert.key)')
    cert_parser.add_argument('-ck', '--cert-key', help='path to certificate key file')
    cert_parser.add_argument('-ck:e', '--cert-key-env', help='env var with path to certificate key file')
    cert_parser.add_argument('-c', '--cert', help='path to public certificate file')
    cert_parser.add_argument('-c:e', '--cert-env', help='env var with path to public certificate file')

    thumbprint_parser = sp.add_parser('thumbprint', help='client cert (cert.key ONLY)')
    msal_parser = sp.add_parser('msal', help='client cert (cert.key ONLY)')
    for p in (thumbprint_parser, msal_parser):
        p.add_argument('-ck', '--cert-key', help='path to certificate key file')
        p.add_argument('-ck:e', '--cert-key-env', help='env var with path to certificate key file')
        # Both methods pass the thumbprint on to token acquisition, so reject a
        # missing one at parse time instead of crashing later on None.
        p.add_argument('-t', '--thumbprint', required=True, help='thumbprint (in lieu of public key)')

    return parser


def _or_read_env(value, env_var):
    """Return *value* if truthy, else the value of env var *env_var* (None if unset or unnamed)."""
    if value:
        return value
    if not env_var:
        # os.getenv(None) raises TypeError; "no env var given" just means "no value".
        return None
    import os
    return os.getenv(env_var)


def _ensure_exists(path) -> pathlib.Path:
    """Interpret *path* as a filesystem path and require that it exists.

    :raises Exception: if *path* is empty/None or does not exist.
    """
    if not path:
        # pathlib.Path(None) raises TypeError and Path('') silently means '.'.
        raise Exception('expected a path but none was given (pass the option or its env var)')
    p = pathlib.Path(path)
    if p.exists():
        return p
    raise Exception(f'interpreting value as path but it does not exist: {p}')


def main():
    """Parse the command line, acquire a token with the chosen method and print it.

    The token is printed to stdout, preceded by a ``result:`` label on stderr. With
    ``--dry-run`` the parsed arguments and resulting ``Client`` are pretty-printed
    and the process exits with status 0.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # vars() returns the namespace's own __dict__, so these rewrites are also
    # visible on ``args`` (the dry-run output shows the normalized values).
    args_dict = vars(args)
    args_dict['scopes'] = (args_dict.get('scopes') or '').split()
    args_dict['tenant_id'] = KNOWN_TENANTS.get(args_dict['tenant_id'], args_dict['tenant_id'])

    client = ignore_properties(Client, args_dict)

    if args.dry_run:
        import pprint
        pprint.pprint(args)
        pprint.pprint(args_dict)
        pprint.pprint(client)
        # The bare exit() builtin is only installed by the site module; sys.exit always exists.
        sys.exit(0)

    if args.method == 'secret':
        secret = _or_read_env(args.client_secret, args.client_secret_env)
        if not secret:
            parser.error('secret: a client secret is required (--client-secret or --client-secret-env)')
        result = client_secret_method(client, secret)
    else:
        # Every certificate-based method needs the private key file.
        key_path = _ensure_exists(_or_read_env(args.cert_key, args.cert_key_env))
        if args.method == 'cert':
            cert_path = _ensure_exists(_or_read_env(args.cert, args.cert_env))
            result = client_certificate_method(client, key_path, cert_path)
        elif args.method == 'thumbprint':
            result = client_thumbprint_method(client, key_path, args.thumbprint)
        elif args.method == 'msal':
            result = msal_method(client, key_path, args.thumbprint)
        else:
            parser.error(f'unknown method: {args.method}')

    print('result:', file=sys.stderr)
    print(result)


if __name__ == '__main__':
    main()