-
-
Save huynhbaoan/c9747be3aa91cd0132717ed7fd302e23 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import requests | |
import boto3 | |
import getpass | |
import base64 | |
import os | |
import csv | |
import xml.etree.ElementTree as ET | |
from bs4 import BeautifulSoup | |
import sys | |
from urllib3.exceptions import InsecureRequestWarning | |
import argparse | |
from typing import Dict, Optional, List | |
# Silence urllib3's InsecureRequestWarning. Every IDP request in this script
# passes verify=False (TLS certificate verification disabled), which would
# otherwise trigger that warning on each call.
# NOTE(review): with verification disabled, the credentials posted to the IDP
# could be intercepted; confirm this is still required.
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
# SAML login settings per environment, keyed by the short code accepted by
# --env. 'account_name' / 'account_id' select the IDP resource and also name
# the services account; 'role_name' is the substring used to pick the role
# from the SAML assertion; 'environment' must match the CSV's Environment column.
ENV_CONFIGS = dict(
    np=dict(
        account_id='803264201466',
        account_name='servicesnp',
        role_name='AUR-Resource-AWS-servicesnp-InfraAnalysis',
        environment='nonprod',
    ),
    preprod=dict(
        account_id='150897553596',
        account_name='servicespreprod',
        role_name='AUR-Resource-AWS-servicespreprod-InfraAnalysis',
        environment='preprod',
    ),
    prod=dict(
        account_id='522412867873',
        account_name='servicesprod',
        role_name='AUR-Resource-AWS-servicesprod-InfraAnalysis',
        environment='prod',
    ),
)
def load_and_filter_csv(csv_file: str, selected_envs: set) -> Dict[str, List[Dict[str, str]]]:
    """Load account mappings from a CSV file, grouped by environment.

    The CSV must contain the columns 'Account ID', 'Environment',
    'Description' and 'URL'. Rows whose Environment is not one of the
    ``environment`` values in ENV_CONFIGS are skipped with a warning. Rows
    whose Environment is in ``selected_envs`` are collected.

    Args:
        csv_file: Path to the CSV file.
        selected_envs: Environment names to keep. These are
            ENV_CONFIGS[...]['environment'] values, e.g. 'nonprod'.

    Returns:
        Mapping of environment name -> list of account dicts with the keys
        'account_id', 'account_name', 'env_type', 'url' and 'switch_role_arn'.
        Environments without matching rows are absent from the mapping.

    Exits via sys.exit when:
        - the file cannot be read;
        - it has no header row;
        - a required column is missing;
        - the CSV contains a known environment that is not in
          ``selected_envs``.
    """
    required_columns = {'Account ID', 'Environment', 'Description', 'URL'}
    known_environments = {config['environment'] for config in ENV_CONFIGS.values()}
    account_mappings: Dict[str, List[Dict[str, str]]] = {}
    csv_environments = set()

    try:
        # newline='' is required by the csv module so that quoted fields
        # containing line breaks are parsed correctly.
        with open(csv_file, 'r', newline='') as f:
            reader = csv.DictReader(f)

            # fieldnames is None for an empty file. Report that explicitly
            # instead of failing with a TypeError in the subset check.
            if reader.fieldnames is None:
                sys.exit(f"CSV file is empty or has no header row: {csv_file}")
            missing = required_columns - set(reader.fieldnames)
            if missing:
                sys.exit(f"CSV file is missing required columns: {', '.join(sorted(missing))}")

            for row in reader:
                env = row['Environment']
                csv_environments.add(env)  # every environment seen, for the check below
                if env not in known_environments:
                    print(f"Warning: Unknown environment '{env}' found in CSV. Skipping record: {row}")
                    continue
                if env not in selected_envs:
                    continue
                account_mappings.setdefault(env, []).append({
                    'account_id': row['Account ID'],
                    # NOTE(review): this is the Environment value. main() uses
                    # it as the credentials profile name, so all accounts of
                    # one environment share a profile and later ones overwrite
                    # earlier ones. Confirm the intended naming.
                    'account_name': row['Environment'],
                    'env_type': row['Description'],
                    'url': row['URL'],
                    'switch_role_arn': f"arn:aws:iam::{row['Account ID']}:role/HIPViewOnlyRole"
                })
    except (OSError, UnicodeDecodeError, csv.Error) as e:
        sys.exit(f"Error reading CSV file: {str(e)}")

    # Refuse to continue when the CSV lists known environments that were not
    # requested on the command line.
    missing_envs = csv_environments.intersection(known_environments) - selected_envs
    if missing_envs:
        print(f"Warning: The following environments are found in the CSV and known in configuration but were not specified in the input arguments: {', '.join(sorted(missing_envs))}")
        sys.exit("Please include all necessary environments in the input arguments and try again.")
    return account_mappings
def get_session(account_name, account_id, quiet=False):
    """Open a requests session against the IDP for one AWS account.

    Builds the IDP URL from ``account_name`` / ``account_id`` and GETs it
    with TLS verification disabled and a 10 s timeout. Then checks that the
    IDP redirected to its login policy page.

    Args:
        account_name: Account name used in the IDP path. An 'aws-' prefix is
            stripped for the ``sn`` query parameter.
        account_id: AWS account ID, sent as the ``acc`` query parameter.
        quiet: When true, do not print the IDP URL.

    Returns:
        ``(session, response)``: the session holding the IDP cookies, and the
        response of the initial GET.

    Exits via sys.exit on any request failure, or when the final URL is the
    hang-up page or anything other than the policy page.
    """
    sn = account_name.replace('aws-', '')
    idp_url = (
        f"https://idp.CDNXHG.com.au/noWebtop/CDNXHG-{account_name}/api/res?id=/Common/AWS-IdP"
        f"&sn={sn}&acc={account_id}"
    )
    if not quiet:
        print(f"IDP URL: {idp_url}")

    session = requests.Session()
    # Timeout must be handled before ConnectionError. requests' ConnectTimeout
    # subclasses both, so listing ConnectionError first would hide timeouts.
    # RequestException is the last-resort catch for everything else.
    try:
        response = session.get(idp_url, verify=False, timeout=10)
    except requests.exceptions.Timeout as err:
        sys.exit(f"Timed out waiting for {idp_url} " + repr(err))
    except requests.exceptions.ConnectionError as err:
        sys.exit(f"Failed to connect to {idp_url} " + repr(err))
    except requests.exceptions.TooManyRedirects as err:
        sys.exit(f"Too many redirects from {idp_url} " + repr(err))
    except requests.exceptions.RequestException as err:
        sys.exit(f"Request to {idp_url} failed " + repr(err))

    if response.url == "https://idp.CDNXHG.com.au/vdesk/hangup.php3":
        sys.exit('ERROR: Redirected to hangup URL. Invalid account name or ID.')
    if response.url != "https://idp.CDNXHG.com.au/my.policy":
        sys.exit(f'ERROR: Unexpected redirect URL: {response.url}')
    return session, response
def authenticate(session, username, password, mfa_token=None, quiet=False, max_attempts=3):
    """Log in to the IDP and return the SAML assertion it issues.

    Posts the credentials to the IDP's my.policy endpoint using ``session``,
    with TLS verification disabled and redirects followed. The session should
    already hold the cookies from get_session(). If the IDP aborts the
    connection, the POST is retried, up to ``max_attempts`` attempts in total.

    Args:
        session: requests.Session returned by get_session().
        username: Login name.
        password: Login password.
        mfa_token: Optional MFA code, sent as the 'token' form field when truthy.
        quiet: Accepted for symmetry with get_session(); not used here.
        max_attempts: Total POST attempts allowed when the connection is
            aborted (values below 1 are treated as 1).

    Returns:
        The value of the response's SAMLResponse form input (the
        base64-encoded assertion).

    Exits via sys.exit when:
        - the request fails;
        - the response is a recognised IDP error page;
        - no SAMLResponse input is present.
    """
    payload = {
        'username': username,
        'password': password,
        'domain': 'aur.national.com.au',
        'vhost': 'standard'
    }
    if mfa_token:
        payload['token'] = mfa_token
    headers = {
        'Content-Type': 'application/x-www-form-urlencoded',
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/91.0.4472.124 Safari/537.36'
    }
    attempts = max(1, max_attempts)

    try:
        auth_response = None
        for attempt in range(1, attempts + 1):
            try:
                auth_response = session.post(
                    'https://idp.CDNXHG.com.au/my.policy',
                    data=payload,
                    verify=False,
                    headers=headers,
                    allow_redirects=True
                )
                auth_response.raise_for_status()
                break
            except requests.exceptions.ConnectionError as e:
                # Only an aborted connection is treated as transient. Any other
                # connection error, or running out of attempts, is fatal.
                if 'Connection aborted.' not in str(e) or attempt == attempts:
                    raise
                print("IDP closed the connection prematurely - trying again")

        # The loop either breaks with a successful response or raises, so
        # auth_response is always set here.
        soup = BeautifulSoup(auth_response.text, 'html.parser')
        assertion = soup.find('input', {'name': 'SAMLResponse'})
        if assertion is not None:
            return assertion.get('value')
        if 'Session Expired/Timeout' in auth_response.text:
            sys.exit("Session expired or timed out")
        if 'Authentication failed' in auth_response.text:
            sys.exit("Authentication failed - check your credentials")
        sys.exit("No SAML assertion found in response")
    except Exception as e:
        # sys.exit raises SystemExit, which is not an Exception subclass, so
        # the exits above pass through unchanged.
        sys.exit(f"Authentication error: {str(e)}")
def select_role(roles, target_role_name):
    """Pick the role containing ``target_role_name``, or ask the user.

    Prints every role with its index. If any roles contain
    ``target_role_name`` as a substring, the last such role is selected
    automatically. Otherwise the user is prompted for an index until a valid
    one is entered.

    Args:
        roles: Sequence of role strings to choose from.
        target_role_name: Substring identifying the preferred role.

    Returns:
        The selected element of ``roles``.

    Exits via sys.exit when ``roles`` is empty (the manual prompt could never
    be satisfied), or when stdin is closed during the manual prompt.
    """
    if not roles:
        sys.exit("ERROR: No roles available to select from.")

    print("\nAvailable roles:")
    selected_index = None
    for i, role in enumerate(roles):
        print(f"[{i}]: {role}")
        if target_role_name in role:
            selected_index = i  # later matches replace earlier ones

    if selected_index is not None:
        print(f"\nAuto-selecting role: {roles[selected_index]}")
        return roles[selected_index]

    print(f"\nWARNING: Target role '{target_role_name}' not found!")
    prompt = "Please select a role manually [0-{}]: ".format(len(roles) - 1)
    while True:
        try:
            selection = int(input(prompt))
        except ValueError:
            print("Please enter a valid number.")
            continue
        except EOFError:
            sys.exit("ERROR: No role selected (input closed).")
        if 0 <= selection < len(roles):
            return roles[selection]
        print("Invalid selection. Try again.")
def verify_authentication(source_profile: str) -> bool:
    """Check that the credentials stored under ``source_profile`` work.

    Builds a boto3 session for the profile and calls STS GetCallerIdentity.
    On success the caller ARN is printed and True is returned. Any failure is
    printed and reported as False rather than raised.
    """
    try:
        sts = boto3.Session(profile_name=source_profile).client('sts')
        caller_arn = sts.get_caller_identity()['Arn']
        print(f"Authenticated as: {caller_arn}")
    except Exception as err:
        print(f"Failed to verify authentication: {str(err)}")
        return False
    return True
def write_credentials(profile_name: str, credentials: Dict):
    """Store STS credentials in ~/.aws/credentials under ``profile_name``.

    Other profiles, and any lines before the first profile header, are kept.
    The target profile is replaced entirely with a region, the access key,
    the secret key and the session token. The session token is also written
    as the legacy ``aws_security_token`` key. Blank lines are not preserved;
    profiles are separated by a single empty line.

    The new contents go to a temporary file in the same directory, which is
    then renamed over the original. An interrupted write therefore cannot
    truncate the file, and the result has mode 0600 (the mode
    tempfile.mkstemp creates files with).

    Args:
        profile_name: Section name to create or replace.
        credentials: Mapping with the keys 'AccessKeyId', 'SecretAccessKey'
            and 'SessionToken'.

    Exits via sys.exit when:
        - the existing file cannot be read (rewriting it would silently drop
          the profiles that failed to load);
        - the new file cannot be written.
    """
    import tempfile

    credentials_path = os.path.expanduser("~/.aws/credentials")
    # Resolve symlinks so the final rename replaces the real file, not the link.
    target_path = os.path.realpath(credentials_path)
    target_dir = os.path.dirname(target_path)
    os.makedirs(target_dir, exist_ok=True)

    preamble: List[str] = []                       # lines before the first [profile]
    existing_credentials: Dict[str, List[str]] = {}
    current_profile = None
    if os.path.exists(target_path):
        try:
            with open(target_path, 'r') as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        continue
                    if line.startswith('[') and line.endswith(']'):
                        current_profile = line[1:-1]
                        existing_credentials[current_profile] = []
                    elif current_profile is None:
                        preamble.append(line)
                    else:
                        existing_credentials[current_profile].append(line)
        except (OSError, UnicodeDecodeError) as e:
            sys.exit(f"Error reading existing credentials file (not overwriting it): {str(e)}")

    # Replace (or append) the target profile; an existing one keeps its position.
    existing_credentials[profile_name] = [
        "region = ap-southeast-2",
        f"aws_access_key_id = {credentials['AccessKeyId']}",
        f"aws_secret_access_key = {credentials['SecretAccessKey']}",
        f"aws_session_token = {credentials['SessionToken']}",
        f"aws_security_token = {credentials['SessionToken']}"  # legacy name, same value
    ]

    chunks = [f"{line}\n" for line in preamble]
    if preamble:
        chunks.append("\n")
    for profile, lines in existing_credentials.items():
        chunks.append(f"[{profile}]\n")
        chunks.extend(f"{line}\n" for line in lines)
        chunks.append("\n")  # empty line between profiles

    tmp_path = None
    try:
        fd, tmp_path = tempfile.mkstemp(dir=target_dir, prefix='.credentials.', suffix='.tmp')
        with os.fdopen(fd, 'w') as f:
            f.write("".join(chunks))
        os.replace(tmp_path, target_path)
    except Exception as e:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
        sys.exit(f"Error writing credentials file: {str(e)}")
def switch_role(source_profile: str, target_role_arn: str, session_name: str,
                env_config: Optional[Dict] = None) -> Dict:
    """Assume ``target_role_arn`` using the credentials of ``source_profile``.

    Args:
        source_profile: Name of the AWS credentials profile to call STS with.
        target_role_arn: ARN of the role to assume.
        session_name: RoleSessionName passed to sts:AssumeRole.
        env_config: Optional environment config. It is accepted so callers
            may pass it, but it is currently unused: no re-authentication
            is attempted.

    Returns:
        The 'Credentials' dict from the sts:AssumeRole response.

    Raises:
        Exception: wrapping any failure, with the original error chained as
            ``__cause__``.
    """
    try:
        session = boto3.Session(profile_name=source_profile)
        sts_client = session.client('sts')
        response = sts_client.assume_role(
            RoleArn=target_role_arn,
            RoleSessionName=session_name
        )
        return response['Credentials']
    except Exception as e:
        raise Exception(f"Failed to switch role: {str(e)}") from e
def _roles_from_saml(saml_response: str) -> List[str]:
    """Return the AWS role attribute values found in a base64 SAML assertion.

    Collects the text of every AttributeValue under the SAML Attribute named
    'https://aws.amazon.com/SAML/Attributes/Role'. AWS documents each value
    as a comma-separated pair of role ARN and SAML provider ARN.

    Exits via sys.exit if the assertion cannot be decoded or parsed.
    """
    saml_ns = '{urn:oasis:names:tc:SAML:2.0:assertion}'
    try:
        root = ET.fromstring(base64.b64decode(saml_response))
    except (TypeError, ValueError, ET.ParseError) as e:
        sys.exit(f"ERROR: Could not decode SAML assertion: {e}")
    roles = []
    for attribute in root.iter(f'{saml_ns}Attribute'):
        if attribute.get('Name') != 'https://aws.amazon.com/SAML/Attributes/Role':
            continue
        for value in attribute.iter(f'{saml_ns}AttributeValue'):
            if value.text and value.text.strip():
                roles.append(value.text.strip())
    return roles


def _assume_role_with_saml(role_entry: str, saml_response: str) -> Dict:
    """Exchange the SAML assertion for temporary credentials for one role.

    Args:
        role_entry: A role attribute value: a role ARN and a SAML provider
            ARN separated by a comma, in either order.
        saml_response: The base64 SAML assertion returned by authenticate().

    Returns:
        The 'Credentials' dict from sts:AssumeRoleWithSAML.

    Exits via sys.exit if the two ARNs cannot be identified or STS rejects
    the request.
    """
    arns = [part.strip() for part in role_entry.split(',')]
    role_arn = next((arn for arn in arns if ':role/' in arn), None)
    principal_arn = next((arn for arn in arns if ':saml-provider/' in arn), None)
    if role_arn is None or principal_arn is None:
        sys.exit(f"ERROR: Could not find role and SAML provider ARNs in: {role_entry}")
    try:
        # AssumeRoleWithSAML does not need existing AWS credentials. The region
        # matches the one write_credentials stores.
        sts_client = boto3.client('sts', region_name='ap-southeast-2')
        response = sts_client.assume_role_with_saml(
            RoleArn=role_arn,
            PrincipalArn=principal_arn,
            SAMLAssertion=saml_response,
        )
    except Exception as e:
        sys.exit(f"ERROR: AssumeRoleWithSAML failed for {role_arn}: {e}")
    return response['Credentials']


def main():
    """Command-line entry point.

    For each --env code:
      1. Log in to the IDP.
      2. Exchange the SAML assertion for credentials on that environment's
         services account, stored in the profile named
         ENV_CONFIGS[env]['account_name'].
      3. Verify those credentials.
      4. Assume the CSV-configured role in every CSV account of that
         environment, writing each result to a profile.
    """
    parser = argparse.ArgumentParser(description='AWS SAML Authentication and Role Switching Tool')
    parser.add_argument('--env', nargs='+', choices=['np', 'preprod', 'prod'],
                        required=True, help='Service environments to authenticate (e.g., np preprod)')
    parser.add_argument('--csv', required=True,
                        help='Path to CSV file containing account mappings')
    parser.add_argument('--username', '-u', help='Username (optional, defaults to $USER)')
    parser.add_argument('--quiet', '-q', action='store_true', help='Minimize output')
    parser.add_argument('--mfa-token', help='MFA Token (optional)')
    args = parser.parse_args()

    print("CDNXHG HIS AWS CLI Access Tool")
    print("===========================\n")

    # Translate --env codes into the environment names used in the CSV.
    selected_envs = {ENV_CONFIGS[env]['environment'] for env in args.env}

    print("\nLoading and validating CSV content...")
    account_mappings = load_and_filter_csv(args.csv, selected_envs)

    # $USERNAME covers Windows, where $USER is usually unset.
    username = (args.username or os.environ.get('USER') or os.environ.get('USERNAME')
                or input("Enter username: "))
    password = getpass.getpass("Password: ")
    mfa_token = args.mfa_token or input("MFA Token (Optional): [Press enter to skip] ")

    for env_code in args.env:
        env_config = ENV_CONFIGS[env_code]
        env_name = env_config['environment']
        source_profile = env_config['account_name']
        print(f"\nAuthenticating to {source_profile} for environment '{env_name}'")

        session, _ = get_session(env_config['account_name'], env_config['account_id'], args.quiet)
        saml_response = authenticate(session, username, password, mfa_token, args.quiet)

        # Turn the SAML assertion into credentials for the services account.
        # They are stored under source_profile, which verify_authentication
        # and switch_role read below.
        roles = _roles_from_saml(saml_response)
        if not roles:
            sys.exit("ERROR: No AWS roles found in the SAML assertion.")
        role_entry = select_role(roles, env_config['role_name'])
        write_credentials(source_profile, _assume_role_with_saml(role_entry, saml_response))

        if not verify_authentication(source_profile):
            sys.exit("Authentication verification failed. Please check your credentials and try again.")

        print(f"\nSwitching roles for environment '{env_name}'")
        for account in account_mappings.get(env_name, []):
            try:
                print(f"\nSwitching to role in account: {account['account_name']} ({account['account_id']})")
                credentials = switch_role(
                    source_profile=source_profile,
                    target_role_arn=account['switch_role_arn'],
                    session_name=f"{username}-session",
                )
                # NOTE(review): account['account_name'] is filled from the
                # CSV's Environment column. Accounts in the same environment
                # may therefore share (and overwrite) one profile. Confirm.
                write_credentials(account['account_name'], credentials)
                print(f"Successfully switched to role in {account['account_name']}")
                print(f"To use these credentials: export AWS_PROFILE={account['account_name']}")
                print(f"Credentials expire: {credentials['Expiration']}")
            except Exception as e:
                # Best effort: one failing account should not stop the rest.
                print(f"Error switching to account {account['account_name']}: {str(e)}")
                print("Continuing with next account...")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment