Skip to content

Instantly share code, notes, and snippets.

@huynhbaoan
Last active November 1, 2024 11:47
Show Gist options
  • Save huynhbaoan/c9747be3aa91cd0132717ed7fd302e23 to your computer and use it in GitHub Desktop.
Save huynhbaoan/c9747be3aa91cd0132717ed7fd302e23 to your computer and use it in GitHub Desktop.
import requests
import boto3
import getpass
import base64
import os
import csv
import xml.etree.ElementTree as ET
from bs4 import BeautifulSoup
import sys
from urllib3.exceptions import InsecureRequestWarning
import argparse
from typing import Dict, Optional, List
# Silence urllib3's InsecureRequestWarning for the whole process: the IDP
# requests in this file are sent with verify=False and would otherwise warn
# on every call.
# NOTE(review): this hides the fact that TLS certificates are not checked
# while credentials are being posted to the IDP.
requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)

# One row per --env choice:
#   (env code, account_id, account_name, role_name, environment label)
# account_name/account_id identify the account logged into through the IDP;
# the environment label is what the CSV "Environment" column is matched
# against when loading account mappings.
_ENV_ROWS = (
    ('np', '803264201466', 'servicesnp',
     'AUR-Resource-AWS-servicesnp-InfraAnalysis', 'nonprod'),
    ('preprod', '150897553596', 'servicespreprod',
     'AUR-Resource-AWS-servicespreprod-InfraAnalysis', 'preprod'),
    ('prod', '522412867873', 'servicesprod',
     'AUR-Resource-AWS-servicesprod-InfraAnalysis', 'prod'),
)

# Environment configurations for SAML authentication, keyed by --env code.
ENV_CONFIGS = {
    code: {
        'account_id': account_id,
        'account_name': account_name,
        'role_name': role_name,
        'environment': environment,
    }
    for code, account_id, account_name, role_name, environment in _ENV_ROWS
}
def load_and_filter_csv(csv_file: str, selected_envs: set) -> Dict[str, List[Dict[str, str]]]:
    """Load per-account role-switch targets from a CSV file.

    The CSV must contain the columns "Account ID", "Environment",
    "Description" and "URL". Rows whose Environment is not one of the
    environment labels in ENV_CONFIGS are reported and skipped; rows whose
    Environment is in *selected_envs* are collected.

    Args:
        csv_file: Path to the CSV file.
        selected_envs: Environment labels (ENV_CONFIGS[...]['environment']
            values) to keep.

    Returns:
        Mapping of environment label -> list of account dicts with the keys
        account_id, account_name, env_type, url and switch_role_arn.

    Exits via sys.exit when the file cannot be read, a required column is
    missing (including an empty file with no header row), or the CSV contains
    a known environment that is not in *selected_envs*.
    """
    account_mappings: Dict[str, List[Dict[str, str]]] = {}
    known_environments = {config['environment'] for config in ENV_CONFIGS.values()}
    csv_environments = set()
    required_columns = {'Account ID', 'Environment', 'Description', 'URL'}
    try:
        # newline='' is what the csv module expects so quoted fields are parsed
        # correctly; utf-8-sig drops a leading BOM (as written by e.g. Excel)
        # that would otherwise become part of the first header name.
        with open(csv_file, 'r', newline='', encoding='utf-8-sig') as f:
            reader = csv.DictReader(f)
            # fieldnames is None when the file is empty.
            missing = required_columns - set(reader.fieldnames or ())
            if missing:
                sys.exit(f"CSV file is missing required columns: {', '.join(sorted(missing))}")
            for row in reader:
                env = row['Environment']
                csv_environments.add(env)  # Track all environments in the CSV
                if env not in known_environments:
                    print(f"Warning: Unknown environment '{env}' found in CSV. Skipping record: {row}")
                    continue
                if env not in selected_envs:
                    continue
                account_mappings.setdefault(env, []).append({
                    'account_id': row['Account ID'],
                    # NOTE(review): account_name is the Environment value, so
                    # every account of one environment shares it. main() uses
                    # it as the credentials profile name, so later accounts
                    # overwrite earlier ones — confirm the intended column.
                    'account_name': row['Environment'],
                    'env_type': row['Description'],
                    'url': row['URL'],
                    'switch_role_arn': f"arn:aws:iam::{row['Account ID']}:role/HIPViewOnlyRole",
                })
    except Exception as e:
        sys.exit(f"Error reading CSV file: {str(e)}")
    # Refuse to run with a partial selection: every known environment present
    # in the CSV must have been requested on the command line.
    missing_envs = csv_environments.intersection(known_environments) - selected_envs
    if missing_envs:
        print(f"Warning: The following environments are found in the CSV and known in configuration but were not specified in the input arguments: {', '.join(sorted(missing_envs))}")
        sys.exit("Please include all necessary environments in the input arguments and try again.")
    return account_mappings
def _normalize_url(url):
    """Return *url* with its scheme and host lower-cased for comparison.

    requests (via urllib3) reports response.url with a lower-cased host, so a
    literal comparison against a mixed-case host such as "idp.CDNXHG..."
    would never match. Path, query and fragment are left untouched.
    """
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(url)
    return urlunsplit((parts.scheme.lower(), parts.netloc.lower(),
                       parts.path, parts.query, parts.fragment))


def get_session(account_name, account_id, quiet=False):
    """Start an IDP session for one AWS account and check where it lands.

    Builds the IDP resource URL for *account_name*/*account_id*, GETs it with
    a fresh requests.Session, and requires the redirect chain to end on the
    IDP's my.policy login page.

    Args:
        account_name: Account name used in the IDP path; the ``sn`` query
            parameter is derived from it by removing every "aws-".
        account_id: AWS account id, sent as the ``acc`` query parameter.
        quiet: When true, the IDP URL is not printed.

    Returns:
        (session, response): the session (holding the IDP cookies) and the
        final GET response.

    Exits via sys.exit on any request failure, on a redirect to the IDP
    hangup page (invalid account name or ID), or on any other final URL.
    """
    sn = account_name.replace('aws-', '')
    idp_url = (
        f"https://idp.CDNXHG.com.au/noWebtop/CDNXHG-{account_name}/api/res?id=/Common/AWS-IdP"
        f"&sn={sn}&acc={account_id}"
    )
    if not quiet:
        print(f"IDP URL: {idp_url}")
    session = requests.Session()
    try:
        # NOTE(review): verify=False disables TLS certificate verification for
        # the IDP; confirm this is really required in the target network.
        response = session.get(idp_url, verify=False, timeout=10)
    except requests.exceptions.ConnectionError as err:
        sys.exit(f"Failed to connect to {idp_url} " + repr(err))
    except requests.exceptions.TooManyRedirects as err:
        sys.exit(f"Too many redirects from {idp_url} " + repr(err))
    except requests.exceptions.Timeout as err:
        sys.exit(f"Connect timeout from {idp_url} " + repr(err))
    except requests.exceptions.RequestException as err:
        # Any other requests failure (invalid URL, etc.) is fatal as well.
        sys.exit(f"Request to {idp_url} failed " + repr(err))
    final_url = _normalize_url(response.url)
    if final_url == _normalize_url("https://idp.CDNXHG.com.au/vdesk/hangup.php3"):
        sys.exit('ERROR: Redirected to hangup URL. Invalid account name or ID.')
    if final_url != _normalize_url("https://idp.CDNXHG.com.au/my.policy"):
        sys.exit(f'ERROR: Unexpected redirect URL: {response.url}')
    return session, response
def authenticate(session, username, password, mfa_token=None, quiet=False,
                 max_attempts=3, timeout=30):
    """Log in to the IDP and return the SAML assertion it hands back.

    Posts the login form to the IDP's my.policy endpoint using *session*,
    which is expected to be the one returned by get_session() and so already
    hold the IDP cookies. It then looks for a ``SAMLResponse`` input in the
    returned HTML.

    Args:
        session: requests.Session used for the POST.
        username: Login name.
        password: Login password.
        mfa_token: Optional MFA token; only sent when truthy.
        quiet: Accepted for symmetry with get_session(); not used here.
        max_attempts: Total POST attempts when the IDP aborts the connection
            ("Connection aborted."); other errors are not retried.
        timeout: Seconds passed to requests as the POST timeout, so a stalled
            IDP cannot hang the tool indefinitely.

    Returns:
        The ``value`` attribute of the SAMLResponse input.

    Exits via sys.exit on an expired session, rejected credentials, a
    response without an assertion, or any other error.
    """
    payload = {
        'username': username,
        'password': password,
        'domain': 'aur.national.com.au',
        'vhost': 'standard'
    }
    if mfa_token:
        payload['token'] = mfa_token
    headers = {
        'Content-Type': 'application/x-www-form-urlencoded',
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/91.0.4472.124 Safari/537.36'
    }
    attempts = max(1, max_attempts)
    try:
        for attempt in range(1, attempts + 1):
            try:
                # NOTE(review): verify=False skips TLS certificate checks while
                # the password is being sent.
                auth_response = session.post(
                    'https://idp.CDNXHG.com.au/my.policy',
                    data=payload,
                    verify=False,
                    headers=headers,
                    allow_redirects=True,
                    timeout=timeout,
                )
                auth_response.raise_for_status()
                break
            except requests.exceptions.ConnectionError as e:
                # Only an aborted connection is treated as transient; any other
                # connection error, or running out of attempts, is fatal.
                if 'Connection aborted.' not in str(e) or attempt == attempts:
                    raise
                print("IDP closed the connection prematurely - trying again")
        soup = BeautifulSoup(auth_response.text, 'html.parser')
        assertion = soup.find('input', {'name': 'SAMLResponse'})
        if assertion is not None:
            return assertion.get('value')
        if 'Session Expired/Timeout' in auth_response.text:
            sys.exit("Session expired or timed out")
        if 'Authentication failed' in auth_response.text:
            sys.exit("Authentication failed - check your credentials")
        sys.exit("No SAML assertion found in response")
    except Exception as e:
        sys.exit(f"Authentication error: {str(e)}")
def select_role(roles, target_role_name):
    """Pick the role to assume from a list of candidate roles.

    Every role is printed with its index. If any role contains
    *target_role_name* as a substring, the last such role is chosen
    automatically. Otherwise the user is prompted for an index until a valid
    one is entered.

    Args:
        roles: Sequence of role strings to choose from.
        target_role_name: Substring identifying the preferred role.

    Returns:
        The selected element of *roles*.

    Exits via sys.exit when *roles* is empty (there is nothing to choose and
    the prompt could never be satisfied) or when stdin closes before a valid
    choice is made.
    """
    if not roles:
        sys.exit("ERROR: No roles available to select from.")
    print("\nAvailable roles:")
    selected_index = None
    for i, role in enumerate(roles):
        print(f"[{i}]: {role}")
        if target_role_name in role:
            selected_index = i  # a later match overrides an earlier one
    if selected_index is not None:
        print(f"\nAuto-selecting role: {roles[selected_index]}")
        return roles[selected_index]
    print(f"\nWARNING: Target role '{target_role_name}' not found!")
    prompt = "Please select a role manually [0-{}]: ".format(len(roles) - 1)
    while True:
        try:
            selection = int(input(prompt))
        except ValueError:
            print("Please enter a valid number.")
            continue
        except EOFError:
            sys.exit("ERROR: No role selected (input closed).")
        # Explicit range check: negative numbers must not wrap around.
        if 0 <= selection < len(roles):
            return roles[selection]
        print("Invalid selection. Try again.")
def verify_authentication(source_profile: str) -> bool:
    """Check that the AWS profile *source_profile* holds working credentials.

    Performs an STS GetCallerIdentity call through a boto3 session for the
    profile and prints the caller ARN. Any exception raised along the way is
    printed and reported as False instead of propagating.

    Returns:
        True if the identity call succeeded, False otherwise.
    """
    try:
        caller_arn = (
            boto3.Session(profile_name=source_profile)
            .client('sts')
            .get_caller_identity()['Arn']
        )
        print(f"Authenticated as: {caller_arn}")
        return True
    except Exception as err:
        print(f"Failed to verify authentication: {str(err)}")
        return False
def write_credentials(profile_name: str, credentials: Dict):
    """Store temporary AWS credentials as *profile_name* in the shared credentials file.

    The existing file is parsed as a simple INI layout: each "[name]" line
    starts a profile and each following non-blank line belongs to it (blank
    lines and lines before the first header are not kept). Existing profiles
    keep their order; *profile_name* is replaced in place or appended.

    The path honours AWS_SHARED_CREDENTIALS_FILE (falling back to
    ~/.aws/credentials), which is where boto3 reads profiles from. The file
    is replaced atomically through a temporary file in the same directory, so
    a failed write cannot truncate it. On POSIX the result is readable by the
    owner only (mode 0600, from tempfile.mkstemp).

    Args:
        profile_name: Profile section name to write.
        credentials: Mapping with AccessKeyId, SecretAccessKey and SessionToken.

    Exits via sys.exit if the existing file cannot be read (rewriting it would
    drop the unread profiles) or the new file cannot be written.
    """
    import tempfile

    credentials_path = os.path.realpath(os.path.expanduser(
        os.environ.get('AWS_SHARED_CREDENTIALS_FILE') or '~/.aws/credentials'))
    credentials_dir = os.path.dirname(credentials_path)
    os.makedirs(credentials_dir, exist_ok=True)

    existing_credentials = {}
    current_profile = None
    if os.path.exists(credentials_path):
        try:
            with open(credentials_path, 'r') as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        continue
                    if line.startswith('[') and line.endswith(']'):
                        current_profile = line[1:-1]
                        existing_credentials[current_profile] = []
                    elif current_profile:
                        existing_credentials[current_profile].append(line)
        except Exception as e:
            sys.exit(f"Error reading credentials file {credentials_path}: {str(e)}")

    existing_credentials[profile_name] = [
        "region = ap-southeast-2",
        f"aws_access_key_id = {credentials['AccessKeyId']}",
        f"aws_secret_access_key = {credentials['SecretAccessKey']}",
        f"aws_session_token = {credentials['SessionToken']}",
        f"aws_security_token = {credentials['SessionToken']}"  # Same as session token
    ]
    # One "[profile]" header, its lines, then an empty separator line each.
    content = ''.join(
        f"[{profile}]\n" + ''.join(f"{line}\n" for line in lines) + "\n"
        for profile, lines in existing_credentials.items()
    )

    tmp_path = None
    try:
        fd, tmp_path = tempfile.mkstemp(dir=credentials_dir, prefix='.credentials.')
        with os.fdopen(fd, 'w') as f:
            f.write(content)
        os.replace(tmp_path, credentials_path)
    except Exception as e:
        if tmp_path is not None and os.path.exists(tmp_path):
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
        sys.exit(f"Error writing credentials file: {str(e)}")
def _sanitize_session_name(session_name: str) -> str:
    """Coerce *session_name* into a value STS accepts as RoleSessionName.

    STS allows 2-64 characters from [A-Za-z0-9_+=,.@-]. Every other character
    (e.g. the space or backslash in "DOMAIN\\first last") is replaced with
    '-', the result is truncated to 64 characters, and it is padded with '-'
    when shorter than 2.
    """
    import re

    cleaned = re.sub(r'[^A-Za-z0-9_+=,.@-]', '-', session_name)[:64]
    return cleaned.ljust(2, '-')


def switch_role(source_profile: str, target_role_arn: str, session_name: str,
                env_config: Optional[Dict] = None) -> Dict:
    """Assume *target_role_arn* using the credentials of an existing profile.

    Args:
        source_profile: AWS profile whose credentials call STS AssumeRole.
        target_role_arn: ARN of the role to assume.
        session_name: Desired RoleSessionName; characters STS rejects are
            replaced first (see _sanitize_session_name).
        env_config: Accepted for callers that pass their environment config
            along; currently unused.

    Returns:
        The ``Credentials`` dict from the AssumeRole response.

    Raises:
        Exception: "Failed to switch role: ..." wrapping (and chained to) the
            underlying error.
    """
    try:
        session = boto3.Session(profile_name=source_profile)
        sts_client = session.client('sts')
        response = sts_client.assume_role(
            RoleArn=target_role_arn,
            RoleSessionName=_sanitize_session_name(session_name)
        )
        return response['Credentials']
    except Exception as e:
        raise Exception(f"Failed to switch role: {str(e)}") from e
# Namespace of SAML 2.0 assertion elements, in ElementTree's "{uri}tag" form.
_SAML_ASSERTION_NS = '{urn:oasis:names:tc:SAML:2.0:assertion}'
# SAML attribute in which the IDP lists the AWS roles the user may assume.
_AWS_SAML_ROLE_ATTRIBUTE = 'https://aws.amazon.com/SAML/Attributes/Role'


def _extract_saml_roles(saml_assertion: str) -> List[str]:
    """Return the AWS role values listed in a base64-encoded SAML response.

    Each returned string is a raw, whitespace-stripped attribute value,
    normally a "role_arn,principal_arn" pair (see _split_role_pair).
    """
    # NOTE(review): xml.etree is not hardened against malicious XML; the
    # assertion is only trusted because it comes from the IDP.
    root = ET.fromstring(base64.b64decode(saml_assertion))
    roles = []
    for attribute in root.iter(_SAML_ASSERTION_NS + 'Attribute'):
        if attribute.get('Name') != _AWS_SAML_ROLE_ATTRIBUTE:
            continue
        for value in attribute.iter(_SAML_ASSERTION_NS + 'AttributeValue'):
            text = (value.text or '').strip()
            if text:
                roles.append(text)
    return roles


def _split_role_pair(role_pair: str) -> tuple:
    """Split a "role_arn,principal_arn" SAML role value into (role_arn, principal_arn).

    Raises:
        ValueError: if the value is not exactly two comma-separated ARNs.
    """
    parts = [part.strip() for part in role_pair.split(',')]
    if len(parts) != 2:
        raise ValueError(f"Unexpected SAML role value: {role_pair!r}")
    first, second = parts
    # IDPs differ in which ARN comes first; the principal is the saml-provider.
    if ':saml-provider/' in first:
        return second, first
    return first, second


def _login_to_source_account(env_config, username, password, mfa_token, quiet=False):
    """Log in to one ENV_CONFIGS account through the IDP and store its credentials.

    The SAML assertion from authenticate() is exchanged for temporary
    credentials with STS AssumeRoleWithSAML. These are written to the profile
    named env_config['account_name'], which is the source profile for
    verify_authentication() and switch_role().
    """
    session, _ = get_session(env_config['account_name'], env_config['account_id'], quiet)
    saml_assertion = authenticate(session, username, password, mfa_token, quiet)
    roles = _extract_saml_roles(saml_assertion)
    if not roles:
        sys.exit("ERROR: No AWS roles found in the SAML assertion.")
    role_arn, principal_arn = _split_role_pair(select_role(roles, env_config['role_name']))
    # AssumeRoleWithSAML needs no prior AWS credentials; the region matches
    # the one written into the credentials profiles.
    sts_client = boto3.client('sts', region_name='ap-southeast-2')
    response = sts_client.assume_role_with_saml(
        RoleArn=role_arn,
        PrincipalArn=principal_arn,
        SAMLAssertion=saml_assertion,
    )
    write_credentials(env_config['account_name'], response['Credentials'])


def main():
    """Command-line entry point.

    For each requested environment:
    - log in through the IDP via SAML;
    - store the resulting credentials under the environment's account_name
      profile and verify them;
    - assume the CSV-listed role (switch_role_arn) in every account of that
      environment, saving each set of credentials as its own profile.
    """
    parser = argparse.ArgumentParser(description='AWS SAML Authentication and Role Switching Tool')
    parser.add_argument('--env', nargs='+', choices=['np', 'preprod', 'prod'],
                        required=True, help='Service environments to authenticate (e.g., np preprod)')
    parser.add_argument('--csv', required=True,
                        help='Path to CSV file containing account mappings')
    parser.add_argument('--username', '-u', help='Username (optional, defaults to $USER)')
    parser.add_argument('--quiet', '-q', action='store_true', help='Minimize output')
    parser.add_argument('--mfa-token', help='MFA Token (optional)')
    args = parser.parse_args()

    print("CDNXHG HIS AWS CLI Access Tool")
    print("===========================\n")

    # Map environment arguments to target environments
    selected_envs = {ENV_CONFIGS[env]['environment'] for env in args.env}
    print("\nLoading and validating CSV content...")
    account_mappings = load_and_filter_csv(args.csv, selected_envs)

    # $USERNAME covers Windows, where $USER is usually unset.
    username = (args.username or os.environ.get('USER')
                or os.environ.get('USERNAME') or input("Enter username: "))
    password = getpass.getpass("Password: ")
    # NOTE(review): the same MFA token is reused for every environment; if the
    # IDP treats tokens as single-use, later environments may be rejected.
    mfa_token = args.mfa_token or input("MFA Token (Optional): [Press enter to skip] ")

    # dict.fromkeys drops repeated --env codes while keeping their order.
    for env_code in dict.fromkeys(args.env):
        env_config = ENV_CONFIGS[env_code]
        env_name = env_config['environment']
        source_profile = env_config['account_name']
        print(f"\nAuthenticating to {source_profile} for environment '{env_name}'")
        try:
            _login_to_source_account(env_config, username, password, mfa_token, args.quiet)
        except Exception as e:
            sys.exit(f"SAML login failed for {source_profile}: {str(e)}")
        if not verify_authentication(source_profile):
            sys.exit("Authentication verification failed. Please check your credentials and try again.")

        print(f"\nSwitching roles for environment '{env_name}'")
        for account in account_mappings.get(env_name, []):
            try:
                print(f"\nSwitching to role in account: {account['account_name']} ({account['account_id']})")
                credentials = switch_role(
                    source_profile=source_profile,
                    target_role_arn=account['switch_role_arn'],
                    session_name=f"{username}-session",
                )
                write_credentials(account['account_name'], credentials)
                print(f"Successfully switched to role in {account['account_name']}")
                print(f"To use these credentials: export AWS_PROFILE={account['account_name']}")
                print(f"Credentials expire: {credentials['Expiration']}")
            except Exception as e:
                # One failing account should not abort the remaining ones.
                print(f"Error switching to account {account['account_name']}: {str(e)}")
                print("Continuing with next account...")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment