# /// script |
# requires-python = ">=3.12" |
# dependencies = [ |
# "boto3" |
# ] |
# /// |
from __future__ import annotations

import argparse
import base64
import hashlib
import re
import secrets
import socket
import sys
import urllib.parse
import uuid
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Dict, NotRequired, Tuple, TypedDict

import boto3
from botocore.paginate import Paginator
class Color:
    """ANSI terminal escape sequences.

    Wrap text as ``Color.YELLOW + text + Color.END``; ``END`` resets every
    attribute that was switched on before it.
    """

    # Foreground colours.
    PURPLE = "\033[95m"
    CYAN = "\033[96m"
    DARKCYAN = "\033[36m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"

    # Text attributes.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

    # Reset all attributes.
    END = "\033[0m"
class Role(TypedDict): |
"""Type definition for an AWS role""" |
accountId: str |
accountName: str |
roleName: str |
roleId: str |
class AwsRole(TypedDict):
    """One item of the ``roleList`` pages yielded by the SSO ``list_account_roles`` paginator."""

    accountId: str  # account the role lives in
    roleName: str  # name used as ``sso_role_name`` in the generated profile
class AwsAccount(TypedDict):
    """One item of the ``accountList`` pages yielded by the SSO ``list_accounts`` paginator."""

    accountId: str  # used to list the account's roles
    accountName: str  # copied onto every Role from this account
    emailAddress: str  # declared for completeness; not read elsewhere in this file
def generate_pkce_verifier() -> str:
    """Return a fresh, cryptographically random PKCE code verifier.

    The verifier is the URL-safe base64 text (padding stripped) of 64 random
    bytes, i.e. 86 characters.
    """
    random_byte_count = 64
    verifier = secrets.token_urlsafe(random_byte_count)
    return verifier
def generate_pkce_challenge(code_verifier: str) -> str:
    """Derive the S256 PKCE code challenge for ``code_verifier``.

    The challenge is the base64url encoding, without ``=`` padding, of the
    SHA-256 digest of the verifier's UTF-8 bytes.
    """
    verifier_bytes = code_verifier.encode("utf-8")
    digest = hashlib.sha256(verifier_bytes).digest()
    encoded = base64.urlsafe_b64encode(digest).decode("utf-8")
    # Padding is not allowed in a PKCE challenge.
    unpadded = encoded.rstrip("=")
    return unpadded
def find_available_port(host: str = "127.0.0.1") -> int:
    """Ask the OS for a currently free TCP port on ``host`` and return its number.

    The probe socket is closed before returning. The port is therefore only
    likely to still be free when the caller binds it, so bind it promptly.

    Args:
        host: Interface to probe. Defaults to IPv4 loopback, which is where the
            OAuth callback server listens.

    Raises:
        RuntimeError: If the probe socket cannot be bound.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            # Port 0 lets the kernel pick an unused ephemeral port.
            probe.bind((host, 0))
            return probe.getsockname()[1]
    except OSError as e:
        raise RuntimeError(f"Failed to bind to {host}: {e}") from e
class OAuthCallbackHandler(BaseHTTPRequestHandler):
    """HTTP handler for the loopback redirect of the OAuth authorization-code + PKCE flow.

    On each GET it reads the ``code``, ``state`` and ``error`` query
    parameters and records them on ``self.server`` as ``oauth_code``,
    ``oauth_state`` and ``oauth_error``. Values already recorded are not
    overwritten. It then answers with a small HTML page telling the user
    whether an authorization code was received.
    """

    # Page layout shared by the success and failure responses. The @TOKEN@
    # placeholders are filled in by plain str.replace. str.format is avoided
    # because the CSS contains literal braces. Only fixed strings from this
    # class are substituted; nothing taken from the request is echoed, so
    # the page cannot be used for HTML injection.
    _PAGE_TEMPLATE = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>aws-cli-config-generator</title>
<style>
body {
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
    background-color: #f8f9fa;
    display: flex;
    justify-content: center;
    align-items: center;
    height: 100vh;
    margin: 0;
}
.container {
    background-color: white;
    padding: 2rem;
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    text-align: center;
    max-width: 400px;
}
.status-icon {
    color: @ICON_COLOR@;
    font-size: 48px;
    margin-bottom: 1rem;
}
h1 {
    color: #212529;
    margin-bottom: 1rem;
    font-size: 1.5rem;
}
p {
    color: #6c757d;
    margin: 0;
}
</style>
</head>
<body>
<div class="container">
<div class="status-icon">@ICON@</div>
<h1>@HEADING@</h1>
<p>@MESSAGE@</p>
</div>
</body>
</html>
"""

    def do_GET(self) -> None:
        """Record the OAuth callback parameters and respond with a status page."""
        params = urllib.parse.parse_qs(urllib.parse.urlparse(self.path).query)
        code = params.get("code", [None])[0]
        state = params.get("state", [None])[0]
        error = params.get("error", [None])[0]

        # Only set values if they haven't been set before.
        if self.server.oauth_code is None:
            self.server.oauth_code = code
        if self.server.oauth_state is None:
            self.server.oauth_state = state
        # getattr: whoever created the server may not have initialised oauth_error.
        if getattr(self.server, "oauth_error", None) is None:
            self.server.oauth_error = error

        if code is not None:
            status = 200
            icon, icon_color = "✓", "#28a745"
            heading = "Authorization complete"
            message = "You can close this window and return to your terminal."
        else:
            # Covers an explicit error redirect and any request without a code.
            status = 400
            icon, icon_color = "✗", "#dc3545"
            heading = "Authorization failed"
            message = "No authorization code was received. Check your terminal for details."

        page = self._PAGE_TEMPLATE
        for token, value in (
            ("@ICON_COLOR@", icon_color),
            ("@ICON@", icon),
            ("@HEADING@", heading),
            ("@MESSAGE@", message),
        ):
            page = page.replace(token, value)
        payload = page.encode("utf-8")

        self.send_response(status)
        self.send_header("Content-type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)

    def log_message(self, format: str, *args: Any) -> None:
        """Suppress the default per-request logging to stderr."""
        pass
def start_temporary_callback_server(
    port: int, host: str = "127.0.0.1", timeout: float = 120
) -> HTTPServer:
    """Create the HTTP server that receives the OAuth redirect; it is bound and listening on return.

    The server binds to loopback only, because the callback carries the
    authorization code and should not be reachable from other hosts.

    Args:
        port: TCP port to listen on.
        host: Interface to bind. Defaults to IPv4 loopback.
        timeout: Seconds ``handle_request()`` waits for a request before giving up.

    Returns:
        The server, with ``oauth_code``, ``oauth_state`` and ``oauth_error``
        set to ``None`` for OAuthCallbackHandler to fill in.
    """
    server = HTTPServer((host, port), OAuthCallbackHandler)
    server.timeout = timeout
    server.oauth_code = None
    server.oauth_state = None
    server.oauth_error = None
    return server
def get_accounts(sso: Any, token: str, paginator: Paginator) -> list[AwsAccount]:
    """Return every account visible to ``token``, flattened across all result pages.

    ``sso`` is accepted for signature symmetry but not used; all calls go
    through ``paginator``, which is expected to be ``list_accounts``.
    """
    pages = paginator.paginate(accessToken=token)
    return [account for page in pages for account in page["accountList"]]
def get_account_roles(
    sso: Any, token: str, account_id: str, paginator: Paginator
) -> list[AwsRole]:
    """Return every role of ``account_id`` available to ``token``, across all result pages.

    ``sso`` is accepted for signature symmetry but not used; all calls go
    through ``paginator``, which is expected to be ``list_account_roles``.
    """
    pages = paginator.paginate(accessToken=token, accountId=account_id)
    return [role for page in pages for role in page["roleList"]]
def _obtain_sso_access_token(start_url: str, region: str) -> str:
    """Run the browser-based OAuth authorization-code + PKCE flow and return an access token.

    The flow registers a public ``sso-oidc`` client and opens the
    authorization URL in the user's browser. It then waits for the redirect
    on a loopback callback server and exchanges the received code with
    ``create_token``.

    Raises:
        TimeoutError: No authorization code reached the callback server.
        RuntimeError: The authorization server redirected back with an ``error``.
        ValueError: The returned ``state`` does not match the one that was sent.
    """
    port = find_available_port()
    # Must be an absolute URL. The identical value is used for client
    # registration, the authorize request and the token exchange.
    redirect_uri = f"http://127.0.0.1:{port}/oauth/callback"

    # Bind the callback server before anything else. This claims the probed
    # port immediately, and a fast redirect (e.g. an already signed-in user)
    # can never reach a port nobody is listening on.
    server = start_temporary_callback_server(port)
    try:
        # Create SSO OIDC client and register
        sso_oidc = boto3.client("sso-oidc", region_name=region)
        client = sso_oidc.register_client(
            clientName="aws-cli-config-generator",
            clientType="public",
            scopes=["sso:account:access"],
            grantTypes=["authorization_code", "refresh_token"],
            issuerUrl=start_url,
            redirectUris=[redirect_uri],
        )

        # Set up PKCE and state
        code_verifier = generate_pkce_verifier()
        code_challenge = generate_pkce_challenge(code_verifier)
        state = str(uuid.uuid4())

        # Construct and open authorization URL
        auth_params = {
            "response_type": "code",
            "client_id": client["clientId"],
            "redirect_uri": redirect_uri,
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "scopes": "sso:account:access",
        }
        auth_url = f"https://oidc.{region}.amazonaws.com/authorize?{urllib.parse.urlencode(auth_params)}"
        print("\nOpening browser for authorization...")
        if not webbrowser.open(auth_url):
            print(f"\nPlease open the following URL in your browser:\n{auth_url}")

        # Blocks until one request arrives or server.timeout elapses.
        server.handle_request()
        if server.oauth_code is None:
            error = getattr(server, "oauth_error", None)
            if error:
                raise RuntimeError(f"Authorization failed: {error}")
            raise TimeoutError("Authorization likely timed out")
        if server.oauth_state != state:
            raise ValueError("State mismatch - possible CSRF attack")

        token = sso_oidc.create_token(
            clientId=client["clientId"],
            clientSecret=client["clientSecret"],
            grantType="authorization_code",
            code=server.oauth_code,
            redirectUri=redirect_uri,
            codeVerifier=code_verifier,
        )
    finally:
        # Release the listening socket on every path, including the errors above.
        server.server_close()
    return token["accessToken"]


def get_available_roles(start_url: str, region: str = "eu-west-1") -> list[Role]:
    """Authenticate against IAM Identity Center and list every role the user can assume.

    Args:
        start_url: The Identity Center start URL, also used as the OIDC issuer URL.
        region: Region hosting the Identity Center instance.

    Returns:
        One Role per (account, role) pair, with the owning account's name attached.
    """
    access_token = _obtain_sso_access_token(start_url, region)

    sso = boto3.client("sso", region_name=region)
    accounts = get_accounts(sso, access_token, sso.get_paginator("list_accounts"))
    # paginate() may be called repeatedly on one paginator, so create it once.
    roles_paginator = sso.get_paginator("list_account_roles")

    roles: list[Role] = []
    for account in accounts:
        account_roles = get_account_roles(
            sso, access_token, account["accountId"], roles_paginator
        )
        roles.extend(
            Role(**{**role, "accountName": account["accountName"]})
            for role in account_roles
        )
    return roles
def sanitize_name(name: str) -> str:
    """Make ``name`` safe for an AWS CLI profile name.

    Every character outside ASCII letters and digits becomes its own ``-``;
    runs are not collapsed. For example, ``"My Account (EU)"`` becomes
    ``"My-Account--EU-"``.
    """
    disallowed = re.compile(r"[^a-zA-Z0-9]")
    return disallowed.sub("-", name)
def _profile_name(
    role: Role,
    session_name: str,
    template: str | None,
    profile_role_name_map: Dict[str, str],
    profile_account_name_map: Dict[str, str],
) -> str:
    """Build the profile name for ``role``, before lower-casing."""
    account_name = profile_account_name_map.get(
        role["accountName"], sanitize_name(role["accountName"])
    )
    role_name = profile_role_name_map.get(role["roleName"], role["roleName"])
    if not template:
        return "-".join([session_name, account_name, role_name])
    name = template
    for placeholder, replacement in (
        ("<account-name>", account_name),
        ("<account-id>", role["accountId"]),
        ("<role-name>", role_name),
    ):
        name = name.replace(placeholder, replacement)
    return name


def create_config(
    roles: list[Role],
    args: argparse.Namespace,
    profile_role_name_map: Dict[str, str],
    profile_account_name_map: Dict[str, str],
    config_account_region_map: Dict[str, str],
) -> None:
    """Print an AWS CLI config snippet for ``roles`` to stdout.

    The snippet holds one ``[sso-session ...]`` section and one
    ``[profile ...]`` section per distinct profile name, sorted by name.

    Args:
        roles: Roles to emit profiles for.
        args: Parsed CLI arguments. Reads ``session_name``, ``start_url``,
            ``region`` and ``profile_name_template``.
        profile_role_name_map: Role name -> name used in the profile name.
        profile_account_name_map: Account name -> name used in the profile
            name. Unmapped names go through sanitize_name.
        config_account_region_map: Account ID -> ``region`` for its profiles.
            Unmapped accounts get ``args.region``.

    Profile names that collide after lower-casing produce a warning on
    stderr, and the last role wins.
    """
    # Default session name: first label of the start URL's host.
    session_name = (
        args.session_name or urllib.parse.urlparse(args.start_url).netloc.split(".")[0]
    )

    # Key by the lower-cased name that is actually written. Otherwise names
    # differing only in case emit duplicate [profile ...] sections, which a
    # strict INI parser rejects. Sorting also then follows the printed names.
    profiles: Dict[str, str] = {}
    for role in roles:
        profile_name = _profile_name(
            role,
            session_name,
            args.profile_name_template,
            profile_role_name_map,
            profile_account_name_map,
        ).lower()
        if profile_name in profiles:
            print(
                f"Warning: profile name '{profile_name}' is generated more than once; "
                f"keeping the one for account {role['accountId']} role {role['roleName']}.",
                file=sys.stderr,
            )
        profiles[profile_name] = f"""[profile {profile_name}]
sso_session = {session_name}
sso_account_id = {role["accountId"]}
sso_role_name = {role["roleName"]}
region = {config_account_region_map.get(role["accountId"], args.region)}
"""

    print("")
    print("")
    print("Copy the snippet below into '$HOME/.aws/config':")
    print(
        f"{Color.YELLOW}; Run aws_cli_config_generator.py to get up-to-date AWS CLI profiles{Color.END}"
    )
    print(
        f"""{Color.YELLOW}[sso-session {session_name}]
sso_start_url = {args.start_url}
sso_region = {args.region}
sso_registration_scopes = sso:account:access
{Color.END}""",
    )
    for profile_name in sorted(profiles):
        print(Color.YELLOW + profiles[profile_name] + Color.END)
def parse_mappings(
    mappings: list[str],
) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str]]:
    """Split ``type:key:value`` strings into three lookup tables.

    Only the first two colons separate fields, so values may contain colons.
    Malformed entries and unknown types print a warning and are skipped.

    Returns:
        ``(profile_role_name_map, profile_account_name_map,
        config_account_region_map)``.
    """
    role_names: Dict[str, str] = {}
    account_names: Dict[str, str] = {}
    account_regions: Dict[str, str] = {}
    destinations = {
        "profile-role-name": role_names,
        "profile-account-name": account_names,
        "config-account-region": account_regions,
    }

    for entry in mappings:
        fields = entry.split(":", 2)
        if len(fields) != 3:
            print(f"Warning: Invalid mapping format '{entry}'. Skipping.")
            continue
        map_type, key, value = fields
        destination = destinations.get(map_type)
        if destination is None:
            print(
                f"Warning: Unknown mapping type '{map_type}' in '{entry}'. Skipping."
            )
            continue
        destination[key] = value

    return role_names, account_names, account_regions
def main() -> None:
    """Command-line entry point: parse arguments, authenticate, and print the config snippet.

    Exits with status 1 on any error. Errors go to stderr so they don't mix
    into the snippet printed on stdout.
    """
    parser = argparse.ArgumentParser(
        description="Generate AWS CLI configuration for AWS SSO roles"
    )
    parser.add_argument("--session-name", help="An optional name of the SSO session")
    parser.add_argument(
        "--profile-name-template",
        help="An optional template to use when naming AWS CLI profiles",
    )
    parser.add_argument(
        "--start-url",
        required=True,
        help="The start URL of the AWS IAM Identity Center instance",
    )
    parser.add_argument(
        "--region",
        default="eu-west-1",
        help="The region of the AWS IAM Identity Center instance",
    )
    parser.add_argument(
        "--mapping",
        nargs="+",
        help="Optional mappings in the format 'type:key:value'. Types: profile-role-name, profile-account-name, config-account-region",
        metavar="TYPE:KEY:VALUE",
    )
    args = parser.parse_args()

    global_role_map, account_name_map, account_region_map = parse_mappings(
        args.mapping or []
    )
    try:
        roles = get_available_roles(args.start_url, args.region)
        create_config(
            roles, args, global_role_map, account_name_map, account_region_map
        )
    except TimeoutError as e:
        print(f"\nError: {e}", file=sys.stderr)
        # sys.exit is always available; the exit() builtin is injected by
        # `site` and missing under `python -S` or some embedded interpreters.
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report any failure as a one-line message.
        print(f"\nUnexpected error: {e}", file=sys.stderr)
        sys.exit(1)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    # main() returns None, so this exits with status 0 on success, exactly
    # like falling off the end of the script.
    raise SystemExit(main())