Created
April 26, 2024 13:12
-
-
Save eoinsha/157f6d869d0033f80a8da5757e8781f7 to your computer and use it in GitHub Desktop.
Script to aid tunnelling to a Bastion ECS container and running `psql`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from functools import cache | |
import os | |
import sys | |
from typing import Literal | |
import json | |
import click | |
import boto3 | |
from rich.console import Console | |
from rich.table import Table | |
# One shared boto3 session; every service client below is created from it so
# they all use the same credentials and region.
session = boto3.session.Session()
rds_client, ecs_client, ssm_client, sts_client = (
    session.client(service) for service in ("rds", "ecs", "ssm", "sts")
)

# Rich console plus the two-column table that `show` fills with parameters.
console = Console()
table = Table(header_style="bold magenta", show_header=True)
table.add_column("Parameter Name", justify="left")
table.add_column("Value", justify="left", overflow="fold", no_wrap=False)

# Account ID of the calling identity (an STS call made at import time); it is
# part of the secret name below.
account = sts_client.get_caller_identity()["Account"]

# These variables are assumed. Others come from SSM Parameter Store
db_port, local_port = 5432, 15432
db_user = "dbuser"
stack_prefix = "DbAccessStack"
secret_name = f"db-admin-secret-{account}-{session.region_name}"
@click.group()
def cli():
    # Root click command group. It does no work itself; subcommands attach to
    # it with @cli.command(). A `#` comment is used here rather than a
    # docstring because click shows a command's docstring as its help text.
    pass
@cli.command()
def show():
    """Show database access details

    Prints database parameters along with commands useful
    for SSM shells and SSM tunnelling.
    """
    # Render every parameter, labelled by the last segment of its name.
    db_params = get_db_params()
    for param_name, param_value in db_params.items():
        table.add_row(param_name.split("/")[-1], param_value)
    console.print(table)

    # Keep only the part of the endpoint before the first ":".
    db_host = db_params["/db/ClusterEndpoint"].split(":")[0]

    # Pick the first ECS cluster whose name starts with the stack prefix.
    cluster_arns = ecs_client.list_clusters()["clusterArns"]
    cluster_names = [cluster_arn.split("/")[-1] for cluster_arn in cluster_arns]
    matching_cluster_names = [
        name for name in cluster_names if name.startswith(stack_prefix)
    ]
    if not matching_cluster_names:
        # The original message interpolated an undefined name (`stack_name`),
        # so this path raised NameError instead of reporting the problem.
        print(
            f"No cluster found with name starting with '{stack_prefix}'.",
            file=sys.stderr,
        )
        sys.exit(1)
    cluster_name = matching_cluster_names[0]

    # Use the first task listed for that cluster.
    tasks = ecs_client.list_tasks(cluster=cluster_name)["taskArns"]
    if not tasks:
        print(f"No tasks found in cluster {cluster_name}", file=sys.stderr)
        sys.exit(1)
    task_id = tasks[0].split("/")[-1]

    # The first container's runtime ID is needed to build the SSM target.
    task = ecs_client.describe_tasks(cluster=cluster_name, tasks=[task_id])["tasks"][0]
    container = task["containers"][0]
    container_runtime_id = container.get("runtimeId")
    if container_runtime_id is None:
        print("Container not (yet?) running", file=sys.stderr)
        sys.exit(1)
    container_name = container["name"]

    console.print("Command to shell into bastion:", style="bold cyan")
    # Backslashes are doubled so a literal "\" is printed at each line end.
    # A single backslash before the newline would be consumed by Python as a
    # line continuation and the command would come out as one run-on line.
    print(
        f"""
aws ecs execute-command --cluster {cluster_name} \\
    --task {task_id} \\
    --container {container_name} \\
    --interactive \\
    --command "/bin/sh"
"""
    )

    # Parameters for the port-forwarding document, as a JSON object of
    # single-element string lists.
    param_json = json.dumps(
        {
            "host": [db_host],
            "portNumber": [str(db_port)],
            "localPortNumber": [str(local_port)],
        }
    )
    target_id = f"ecs:{cluster_name}_{task_id}_{container_runtime_id}"
    console.print("Command to tunnel to DB cluster:", style="bold cyan")
    print(
        f"""
aws ssm start-session --target {target_id} \\
    --document-name AWS-StartPortForwardingSessionToRemoteHost \\
    --parameters '{param_json}'
"""
    )
def run(cmd, envs):
    """Run an interactive shell application

    Args:
        cmd: Command line as a single string. It is tokenised with
            shell-like rules, so quoted arguments stay together.
        envs: Mapping of extra environment variables; these are layered on
            top of (and override) the current process environment.

    Returns:
        The status returned by ``os.spawnvpe`` in wait mode. 127 is treated
        as "command not found" and reported on stderr.
    """
    # Local import so this block is self-contained.
    import shlex

    # shlex.split honours quoting; a plain str.split() would tear an argument
    # such as 'exit 3' into two pieces.
    argv = shlex.split(cmd)
    code = os.spawnvpe(os.P_WAIT, argv[0], argv, {**os.environ, **envs})
    if code == 127:
        sys.stderr.write(f"{argv[0]}: command not found\n")
    return code
def get_rds_iam_password():
    """Return an RDS auth token to use as the database password.

    The hostname is the part of the "/db/ClusterEndpoint" parameter before
    the first ":"; port and user come from the module-level settings and
    the region from the shared session.
    """
    endpoint = get_db_params()["/db/ClusterEndpoint"]
    hostname, _, _ = endpoint.partition(":")
    return rds_client.generate_db_auth_token(
        DBHostname=hostname,
        Port=db_port,
        DBUsername=db_user,
        Region=session.region_name,
    )
def get_secrets_manager_secret():
    """Return the "password" field of the admin secret in Secrets Manager.

    The secret named by the module-level ``secret_name`` is read as a JSON
    document and its "password" value is returned.
    """
    client = session.client("secretsmanager")
    response = client.get_secret_value(SecretId=secret_name)
    document = json.loads(response["SecretString"])
    return document["password"]
@cli.command()
@click.option(
    "--with-password",
    type=click.Choice(["none", "iam", "secretsmanager"]),
    default="none",
)
@click.option("--connect", is_flag=True, default=False)
def psql(
    with_password: Literal["none", "iam", "secretsmanager"], connect: bool = False
):
    """
    Generate a psql command

    Generates a psql command (and optionally runs) for use with the local tunnel"""
    # None means no real password was requested.
    password = None
    if with_password == "iam":
        password = get_rds_iam_password()
    elif with_password == "secretsmanager":
        password = get_secrets_manager_secret()

    cmd = f"psql -h localhost -p {local_port} -U {db_user} main"
    if connect:
        # Export PGPASSWORD only when a real password was fetched. Previously
        # the literal placeholder "PASSWORD" was exported in the "none" case,
        # overriding any PGPASSWORD already set in the caller's environment.
        envs = {} if password is None else {"PGPASSWORD": password}
        run(cmd, envs)
    else:
        # The printed command keeps the PASSWORD placeholder for the user to
        # fill in when no password source was selected.
        shown = "PASSWORD" if password is None else password
        print(f"PGPASSWORD='{shown}' {cmd}")
@cache
def get_db_params():
    """Return every SSM parameter under "/db" as a name -> value dict.

    The lookup is recursive and the result is memoised by ``@cache``, so
    Parameter Store is queried at most once per process.
    """
    # get_parameters_by_path returns results in pages; a single call yields
    # only the first page, so walk all pages with the paginator.
    paginator = ssm_client.get_paginator("get_parameters_by_path")
    pages = paginator.paginate(Path="/db", Recursive=True)
    return {
        param["Name"]: param["Value"]
        for page in pages
        for param in page["Parameters"]
    }
# Invoke the click command group only when this file is run as a script.
if __name__ == "__main__":
    cli()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment