|
# /// script |
|
# requires-python = ">=3.12" |
|
# dependencies = [ |
|
# "boto3" |
|
# ] |
|
# /// |
|
from __future__ import annotations

import argparse
import base64
import hashlib
import re
import secrets
import socket
import sys
import urllib.parse
import uuid
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Dict, NotRequired, Tuple, TypedDict

import boto3
from botocore.paginate import Paginator
|
|
|
|
|
class Color:
    """ANSI escape sequences used to colour terminal output."""

    # Foreground colours
    PURPLE = "\033[95m"
    CYAN = "\033[96m"
    DARKCYAN = "\033[36m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"

    # Text attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

    # Resets every colour and attribute set above
    END = "\033[0m"
|
|
|
|
|
class Role(TypedDict): |
|
"""Type definition for an AWS role""" |
|
|
|
accountId: str |
|
accountName: str |
|
roleName: str |
|
roleId: str |
|
|
|
|
|
class AwsRole(TypedDict):
    """One role entry from an SSO ``list_account_roles`` page (its ``roleList``)."""

    # ID of the account the role lives in
    accountId: str
    # Name of the role / permission set
    roleName: str
|
|
|
|
|
class AwsAccount(TypedDict):
    """One account entry from an SSO ``list_accounts`` page (its ``accountList``)."""

    # Numeric account identifier, kept as a string
    accountId: str
    # Human-readable account name
    accountName: str
    # Email address registered for the account
    emailAddress: str
|
|
|
|
|
def generate_pkce_verifier() -> str:
    """Return a fresh, cryptographically random PKCE ``code_verifier``.

    64 random bytes are base64url-encoded with the ``=`` padding removed,
    producing an 86-character string over ``[A-Za-z0-9_-]``.
    """
    entropy = secrets.token_bytes(64)
    return base64.urlsafe_b64encode(entropy).rstrip(b"=").decode("ascii")
|
|
|
|
|
def generate_pkce_challenge(code_verifier: str) -> str: |
|
"""Generate PKCE challenge from verifier using S256 method""" |
|
code_challenge_digest = hashlib.sha256(code_verifier.encode("utf-8")).digest() |
|
code_challenge = base64.urlsafe_b64encode(code_challenge_digest).decode("utf-8") |
|
return code_challenge.rstrip("=") |
|
|
|
|
|
def find_available_port() -> int:
    """Return a TCP port on 127.0.0.1 that was free at the time of the call.

    The probe socket is closed before returning, so another process may claim
    the port before the caller binds it. When possible, bind the real listening
    socket to port 0 instead and read back the port it was given.

    Raises:
        RuntimeError: if binding to the loopback interface fails (the original
            ``OSError`` is chained as ``__cause__``).
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            # Port 0 asks the OS for any currently unused port.
            probe.bind(("127.0.0.1", 0))
            return probe.getsockname()[1]
    except OSError as e:  # socket.error is an alias of OSError
        raise RuntimeError(f"Failed to bind to localhost: {e}") from e
|
|
|
|
|
class OAuthCallbackHandler(BaseHTTPRequestHandler):
    """A basic HTTP request handler for receiving a callback as part of the OAuth PKCE flow.

    Each GET records the ``code``, ``state`` and ``error`` query parameters of the
    redirect on the owning server as ``oauth_code``, ``oauth_state`` and
    ``oauth_error``. An attribute is only written while it is still ``None``.
    The browser is then shown a page saying whether authorization succeeded.
    """

    def do_GET(self) -> None:
        """Record the OAuth redirect parameters and reply with a status page."""
        params = urllib.parse.parse_qs(urllib.parse.urlparse(self.path).query)
        code = params.get("code", [None])[0]
        state = params.get("state", [None])[0]
        # Set by the authorization server instead of ``code`` when the user
        # denies access or the request is rejected.
        error = params.get("error", [None])[0]

        # Only set values if they haven't been set before.
        if self.server.oauth_code is None:
            self.server.oauth_code = code
        if self.server.oauth_state is None:
            self.server.oauth_state = state
        # getattr: tolerate servers that were created without ``oauth_error``.
        if getattr(self.server, "oauth_error", None) is None:
            self.server.oauth_error = error

        body = self._render_page(succeeded=code is not None and error is None)
        self.send_response(200)
        self.send_header("Content-type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    @staticmethod
    def _render_page(succeeded: bool) -> bytes:
        """Return the UTF-8 encoded status page shown in the browser after the redirect."""
        if succeeded:
            icon, icon_color, title = "✓", "#28a745", "Authorization complete"
        else:
            icon, icon_color, title = "✗", "#dc3545", "Authorization failed"
        # No request-controlled text is interpolated, so no HTML escaping is needed.
        page = f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>aws-cli-config-generator</title>
    <style>
        body {{
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
            background-color: #f8f9fa;
            display: flex;
            justify-content: center;
            align-items: center;
            height: 100vh;
            margin: 0;
        }}
        .container {{
            background-color: white;
            padding: 2rem;
            border-radius: 8px;
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
            text-align: center;
            max-width: 400px;
        }}
        .status-icon {{
            color: {icon_color};
            font-size: 48px;
            margin-bottom: 1rem;
        }}
        h1 {{
            color: #212529;
            margin-bottom: 1rem;
            font-size: 1.5rem;
        }}
        p {{
            color: #6c757d;
            margin: 0;
        }}
    </style>
</head>
<body>
    <div class="container">
        <div class="status-icon">{icon}</div>
        <h1>{title}</h1>
        <p>You can close this window and return to your terminal.</p>
    </div>
</body>
</html>
"""
        return page.encode("utf-8")

    def log_message(self, format: str, *args: Any) -> None:
        """Suppress logging of requests"""
        pass
|
|
|
|
|
def start_temporary_callback_server(port: int, *, timeout: float = 120) -> HTTPServer:
    """Create and bind the local HTTP server that receives the OAuth redirect.

    The server listens on 127.0.0.1 but does not serve until the caller invokes
    ``handle_request()``. The caller must release it with ``server_close()``
    (or by using it as a context manager).

    Args:
        port: TCP port to bind; ``0`` lets the OS choose a free port, which is
            then available as ``server.server_port``.
        timeout: Seconds ``handle_request()`` waits for a request before
            returning without one (default: 2 minutes).

    Returns:
        The bound server with ``oauth_code``, ``oauth_state`` and
        ``oauth_error`` initialised to ``None``.
    """
    server = HTTPServer(("127.0.0.1", port), OAuthCallbackHandler)
    server.timeout = timeout
    # Callback results; the request handler fills in the ones it records.
    server.oauth_code = None
    server.oauth_state = None
    server.oauth_error = None
    return server
|
|
|
|
|
def get_accounts(sso: Any, token: str, paginator: Paginator) -> list[AwsAccount]:
    """Collect the ``accountList`` entries from every page ``paginator`` yields.

    ``sso`` is accepted but not used; all requests go through ``paginator``
    (the SSO ``list_accounts`` paginator at this file's call site).
    """
    pages = paginator.paginate(accessToken=token)
    return [account for page in pages for account in page["accountList"]]
|
|
|
|
|
def get_account_roles(
    sso: Any, token: str, account_id: str, paginator: Paginator
) -> list[AwsRole]:
    """Collect the ``roleList`` entries for ``account_id`` across all pages.

    ``sso`` is accepted but not used; all requests go through ``paginator``
    (the SSO ``list_account_roles`` paginator at this file's call site).
    """
    collected: list[AwsRole] = []
    for response in paginator.paginate(accessToken=token, accountId=account_id):
        collected += response["roleList"]
    return collected
|
|
|
|
|
def get_available_roles(start_url: str, region: str = "eu-west-1") -> list[Role]:
    """Log in to IAM Identity Center in the browser and list every assumable role.

    Runs the OAuth authorization-code flow with PKCE against the SSO OIDC
    service, then enumerates all accounts and their roles with the resulting
    access token.

    Args:
        start_url: IAM Identity Center start URL (used as the OIDC issuer URL).
        region: Region hosting the IAM Identity Center instance.

    Returns:
        One ``Role`` per (account, role) pair, carrying the account's name.

    Raises:
        RuntimeError: if the browser is redirected back with an OAuth ``error``.
        TimeoutError: if no authorization code arrives before the callback
            server's timeout.
        ValueError: if the returned ``state`` does not match (possible CSRF).
    """
    access_token = _request_access_token(start_url, region)
    return _list_roles(access_token, region)


def _request_access_token(start_url: str, region: str) -> str:
    """Run the browser-based PKCE authorization flow and return an SSO access token."""
    sso_oidc = boto3.client("sso-oidc", region_name=region)

    # Bind the callback server before anything else: port 0 lets the OS pick a
    # free port atomically (no probe-then-rebind race), and the socket is
    # already listening when the browser is redirected. ``with`` guarantees the
    # server is closed on every exit path, including the checks below.
    with start_temporary_callback_server(0) as server:
        redirect_uri = f"http://127.0.0.1:{server.server_port}"

        # Create SSO OIDC client and register
        client = sso_oidc.register_client(
            clientName="aws-cli-config-generator",
            clientType="public",
            scopes=["sso:account:access"],
            grantTypes=["authorization_code", "refresh_token"],
            issuerUrl=start_url,
            redirectUris=[redirect_uri],
        )

        # Set up PKCE and state
        code_verifier = generate_pkce_verifier()
        code_challenge = generate_pkce_challenge(code_verifier)
        state = str(uuid.uuid4())

        auth_params = {
            "response_type": "code",
            "client_id": client["clientId"],
            "redirect_uri": redirect_uri,
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "scopes": "sso:account:access",
        }
        # The client's endpoint already reflects region and partition, unlike a
        # hard-coded "oidc.<region>.amazonaws.com" host.
        auth_url = (
            f"{sso_oidc.meta.endpoint_url}/authorize?"
            f"{urllib.parse.urlencode(auth_params)}"
        )

        print("\nOpening browser for authorization...")
        if not webbrowser.open(auth_url):
            print(f"\nPlease open the following URL in your browser:\n{auth_url}")

        server.handle_request()

        # getattr: the attribute may be absent depending on how the server was set up.
        error = getattr(server, "oauth_error", None)
        if error is not None:
            raise RuntimeError(f"Authorization failed: {error}")
        if server.oauth_code is None:
            raise TimeoutError("Authorization likely timed out")
        if server.oauth_state != state:
            raise ValueError("State mismatch - possible CSRF attack")
        code = server.oauth_code

    token = sso_oidc.create_token(
        clientId=client["clientId"],
        clientSecret=client["clientSecret"],
        grantType="authorization_code",
        code=code,
        redirectUri=redirect_uri,
        codeVerifier=code_verifier,
    )
    return token["accessToken"]


def _list_roles(access_token: str, region: str) -> list[Role]:
    """List every account visible to ``access_token`` and every role in each."""
    sso = boto3.client("sso", region_name=region)
    accounts = get_accounts(sso, access_token, sso.get_paginator("list_accounts"))
    # One paginator serves all accounts; each paginate() call starts afresh.
    role_paginator = sso.get_paginator("list_account_roles")
    return [
        Role(**{**role, "accountName": account["accountName"]})
        for account in accounts
        for role in get_account_roles(
            sso, access_token, account["accountId"], role_paginator
        )
    ]
|
|
|
|
|
def sanitize_name(name: str) -> str:
    """Make ``name`` safe for an AWS profile name.

    Every character outside ASCII letters and digits is replaced by ``-``,
    one dash per character.
    """
    return "".join(ch if ch.isascii() and ch.isalnum() else "-" for ch in name)
|
|
|
|
|
def create_config(
    roles: list[Role],
    args: argparse.Namespace,
    profile_role_name_map: Dict[str, str],
    profile_account_name_map: Dict[str, str],
    config_account_region_map: Dict[str, str],
) -> None:
    """Print an AWS CLI config snippet: one ``sso-session`` plus a profile per role.

    Args:
        roles: Roles to generate profiles for.
        args: Parsed CLI arguments; reads ``session_name``, ``start_url``,
            ``region`` and ``profile_name_template``.
        profile_role_name_map: Role name -> replacement used in profile names.
        profile_account_name_map: Account name -> replacement used in profile
            names (otherwise the sanitized account name is used).
        config_account_region_map: Account ID -> default ``region`` for its
            profiles (otherwise ``args.region``).

    Profiles are printed sorted by their (lower-cased) name. If several roles
    map to the same profile name, a warning goes to stderr and the last one wins.
    """
    # Default session name: first label of the start URL's host, e.g.
    # "d-1234567890" for "https://d-1234567890.awsapps.com/start".
    session_name = (
        args.session_name or urllib.parse.urlparse(args.start_url).netloc.split(".")[0]
    )
    print("")
    print("")
    print("Copy the snippet below into '$HOME/.aws/config':")
    print(
        f"{Color.YELLOW}; Run aws_cli_config_generator.py to get up-to-date AWS CLI profiles{Color.END}"
    )
    print(
        f"""{Color.YELLOW}[sso-session {session_name}]
sso_start_url = {args.start_url}
sso_region = {args.region}
sso_registration_scopes = sso:account:access
{Color.END}""",
    )

    profiles: dict[str, str] = {}
    for role in roles:
        account_label = profile_account_name_map.get(
            role["accountName"], sanitize_name(role["accountName"])
        )
        role_label = profile_role_name_map.get(role["roleName"], role["roleName"])

        if args.profile_name_template:
            profile_name = args.profile_name_template
            for pattern, replacement in (
                ("<account-name>", account_label),
                ("<account-id>", role["accountId"]),
                ("<role-name>", role_label),
            ):
                profile_name = profile_name.replace(pattern, replacement)
        else:
            profile_name = "-".join([session_name, account_label, role_label])

        # Key by the name actually written to the file, so the output is sorted
        # as displayed and case-only variants cannot produce duplicate sections.
        profile_name = profile_name.lower()
        if profile_name in profiles:
            print(
                f"Warning: several roles map to profile '{profile_name}'; "
                f"keeping {role['accountId']}/{role['roleName']}.",
                file=sys.stderr,
            )

        profiles[profile_name] = f"""[profile {profile_name}]
sso_session = {session_name}
sso_account_id = {role["accountId"]}
sso_role_name = {role["roleName"]}
region = {config_account_region_map.get(role["accountId"], args.region)}
"""
    for profile_name in sorted(profiles):
        print(Color.YELLOW + profiles[profile_name] + Color.END)
|
|
|
|
|
def parse_mappings(
    mappings: list[str],
) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str]]:
    """Split ``TYPE:KEY:VALUE`` strings into the three supported lookup tables.

    Only the first two colons separate fields, so VALUE may itself contain ``:``.
    Entries with fewer than two colons, or with an unknown TYPE, are reported
    with a printed warning and skipped.

    Returns:
        ``(profile_role_name_map, profile_account_name_map, config_account_region_map)``.
    """
    role_names: dict[str, str] = {}
    account_names: dict[str, str] = {}
    account_regions: dict[str, str] = {}
    targets = {
        "profile-role-name": role_names,
        "profile-account-name": account_names,
        "config-account-region": account_regions,
    }

    for entry in mappings:
        fields = entry.split(":", 2)
        if len(fields) != 3:
            print(f"Warning: Invalid mapping format '{entry}'. Skipping.")
            continue
        kind, key, value = fields
        target = targets.get(kind)
        if target is None:
            print(
                f"Warning: Unknown mapping type '{kind}' in '{entry}'. Skipping."
            )
            continue
        target[key] = value

    return role_names, account_names, account_regions
|
|
|
|
|
def main() -> None:
    """Parse command-line arguments, fetch the user's SSO roles and print the config.

    Errors are written to stderr (keeping stdout to the config snippet) and the
    process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Generate AWS CLI configuration for AWS SSO roles"
    )
    parser.add_argument("--session-name", help="An optional name of the SSO session")
    parser.add_argument(
        "--profile-name-template",
        help="An optional template to use when naming AWS CLI profiles",
    )
    parser.add_argument(
        "--start-url",
        required=True,
        help="The start URL of the AWS IAM Identity Center instance",
    )
    parser.add_argument(
        "--region",
        default="eu-west-1",
        help="The region of the AWS IAM Identity Center instance",
    )
    parser.add_argument(
        "--mapping",
        nargs="+",
        # "extend" accumulates values across repeated --mapping flags instead of
        # silently keeping only the last occurrence.
        action="extend",
        help="Optional mappings in the format 'type:key:value'. Types: profile-role-name, profile-account-name, config-account-region",
        metavar="TYPE:KEY:VALUE",
    )
    args = parser.parse_args()

    global_role_map, account_name_map, account_region_map = parse_mappings(
        args.mapping or []
    )

    try:
        roles = get_available_roles(args.start_url, args.region)
        create_config(
            roles, args, global_role_map, account_name_map, account_region_map
        )
    except TimeoutError as e:
        print(f"\nError: {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:  # top-level boundary: report and exit non-zero
        print(f"\nUnexpected error: {e}", file=sys.stderr)
        sys.exit(1)
|
|
|
|
|
if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()