Skip to content

Instantly share code, notes, and snippets.

@eoinsha
Created April 26, 2024 13:12
Show Gist options
  • Save eoinsha/157f6d869d0033f80a8da5757e8781f7 to your computer and use it in GitHub Desktop.
Save eoinsha/157f6d869d0033f80a8da5757e8781f7 to your computer and use it in GitHub Desktop.
Script that helps tunnel to the database through a bastion ECS container and run `psql`
#!/usr/bin/env python3
import json
import os
import shlex
import signal
import subprocess
import sys
from functools import cache
from typing import Literal

import boto3
import click
from rich.console import Console
from rich.table import Table
# AWS clients, all created from one boto3 session so they share its
# credentials and region.
session = boto3.session.Session()
rds_client = session.client("rds")
ecs_client = session.client("ecs")
ssm_client = session.client("ssm")
sts_client = session.client("sts")
# Rich console plus the parameter table that the `show` command fills and prints.
console = Console()
table = Table(show_header=True, header_style="bold magenta")
table.add_column("Parameter Name", justify="left")
table.add_column("Value", justify="left", no_wrap=False, overflow="fold")
# NOTE(review): this STS call runs at import time, so every invocation
# (including `--help`) needs working AWS credentials and network access.
account = sts_client.get_caller_identity()["Account"]
# These values are hard-coded assumptions about the deployment; the rest
# (e.g. the cluster endpoint) are read from SSM Parameter Store under /db.
db_port = 5432  # remote DB port: tunnel target and IAM token port
local_port = 15432  # local end of the SSM port-forwarding tunnel; psql connects here
db_user = "dbuser"  # user for IAM auth tokens and the generated psql command
stack_prefix = "DbAccessStack"  # ECS cluster name prefix used to locate the bastion
# Secrets Manager secret whose JSON "password" field `psql --with-password
# secretsmanager` uses; the name is scoped by account and region.
secret_name = f"db-admin-secret-{account}-{session.region_name}"
@click.group()
def cli():
    """Reach the database through the bastion ECS container.

    Run "show" to print the connection parameters plus the commands for
    shelling into the bastion and opening the SSM tunnel, then "psql" to
    build (or run) a psql command against the local end of that tunnel.
    """
@cli.command()
def show():
    """Show database access details.

    Prints the database parameters stored under /db in SSM Parameter Store,
    followed by ready-to-paste commands for opening a shell in the bastion
    container (ECS Exec) and for forwarding a local port to the database
    cluster through that container (SSM port forwarding).
    """
    db_params = get_db_params()
    for param_name, param_value in db_params.items():
        table.add_row(param_name.split("/")[-1], param_value)
    console.print(table)

    endpoint = db_params.get("/db/ClusterEndpoint")
    if endpoint is None:
        _fail("SSM parameter '/db/ClusterEndpoint' not found.")
    # The endpoint may carry a ":port" suffix; keep only the host name.
    db_host = endpoint.split(":")[0]

    cluster_name = _find_cluster_name()
    task_id, container_name, container_runtime_id = _find_bastion_container(
        cluster_name
    )

    console.print("Command to shell into bastion:", style="bold cyan")
    # "\\" emits a literal backslash so the output is a multi-line shell command.
    print(
        f"""
aws ecs execute-command --cluster {cluster_name} \\
    --task {task_id} \\
    --container {container_name} \\
    --interactive \\
    --command "/bin/sh"
"""
    )

    param_json = json.dumps(
        {
            "host": [db_host],
            "portNumber": [str(db_port)],
            "localPortNumber": [str(local_port)],
        }
    )
    # SSM addresses the ECS container as ecs:<cluster>_<task-id>_<runtime-id>.
    target_id = f"ecs:{cluster_name}_{task_id}_{container_runtime_id}"
    console.print("Command to tunnel to DB cluster:", style="bold cyan")
    print(
        f"""
aws ssm start-session --target {target_id} \\
    --document-name AWS-StartPortForwardingSessionToRemoteHost \\
    --parameters '{param_json}'
"""
    )


def _fail(message):
    """Print *message* to stderr and exit with status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _find_cluster_name():
    """Return the first ECS cluster name starting with stack_prefix, or exit."""
    # list_clusters is paginated; walk every page rather than only the first.
    paginator = ecs_client.get_paginator("list_clusters")
    for page in paginator.paginate():
        for cluster_arn in page["clusterArns"]:
            name = cluster_arn.split("/")[-1]
            if name.startswith(stack_prefix):
                return name
    _fail(f"No cluster found with name starting with '{stack_prefix}'.")


def _find_bastion_container(cluster_name):
    """Return (task_id, container_name, runtime_id) of the first task's first container.

    Exits with status 1 if the cluster has no tasks or the container has no
    runtime ID yet (i.e. it is not running).
    """
    task_arns = ecs_client.list_tasks(cluster=cluster_name)["taskArns"]
    if not task_arns:
        _fail(f"No tasks found in cluster {cluster_name}")
    task_id = task_arns[0].split("/")[-1]
    task = ecs_client.describe_tasks(cluster=cluster_name, tasks=[task_id])["tasks"][0]
    container = task["containers"][0]
    runtime_id = container.get("runtimeId")
    if runtime_id is None:
        _fail("Container not (yet?) running")
    return task_id, container["name"], runtime_id
def run(cmd, envs):
    """Run an interactive command in the foreground and return its exit status.

    Args:
        cmd: Command line to execute. It is split into arguments with
            shell-like quoting rules (shlex) but is NOT run through a shell.
        envs: Extra environment variables layered over os.environ.

    Returns:
        The child's exit code; -N if it was killed by signal N; 127 if the
        program could not be found; 126 if it could not be executed.

    Raises:
        ValueError: if cmd contains no program to run.
    """
    argv = shlex.split(cmd)
    if not argv:
        raise ValueError("empty command")
    try:
        proc = subprocess.Popen(argv, env={**os.environ, **envs})
    except FileNotFoundError:
        sys.stderr.write(f"{argv[0]}: command not found\n")
        return 127
    except PermissionError:
        sys.stderr.write(f"{argv[0]}: permission denied\n")
        return 126
    # Ctrl-C is meant for the interactive child (psql uses it to cancel a
    # query); ignore it here so it doesn't kill this wrapper and orphan the
    # child. Installed after Popen so the child does not inherit SIG_IGN.
    previous = signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        return proc.wait()
    finally:
        # signal.signal returns None when the old handler wasn't set from Python.
        if previous is not None:
            signal.signal(signal.SIGINT, previous)
def get_rds_iam_password():
    """Return an RDS IAM authentication token to use as db_user's password.

    The token is requested for the cluster endpoint stored in SSM under
    /db/ClusterEndpoint (with any ":port" suffix removed), on db_port, in the
    session's region.
    """
    endpoint = get_db_params()["/db/ClusterEndpoint"]
    host, _, _ = endpoint.partition(":")
    return rds_client.generate_db_auth_token(
        DBHostname=host,
        Port=db_port,
        DBUsername=db_user,
        Region=session.region_name,
    )
def get_secrets_manager_secret():
    """Return the "password" field of the JSON secret named by secret_name."""
    response = session.client("secretsmanager").get_secret_value(
        SecretId=secret_name
    )
    return json.loads(response["SecretString"])["password"]
@cli.command()
@click.option(
    "--with-password",
    type=click.Choice(["none", "iam", "secretsmanager"]),
    default="none",
    help=(
        "Where to get the password: 'none' (placeholder when printing; leave "
        "PGPASSWORD alone when connecting), 'iam' (RDS IAM auth token) or "
        "'secretsmanager'."
    ),
)
@click.option(
    "--connect",
    is_flag=True,
    default=False,
    help="Run psql now instead of printing the command.",
)
def psql(
    with_password: Literal["none", "iam", "secretsmanager"], connect: bool = False
):
    """Generate a psql command.

    Generates a psql command for use with the local tunnel, and optionally
    runs it.
    """
    password = None
    if with_password == "iam":
        password = get_rds_iam_password()
    elif with_password == "secretsmanager":
        password = get_secrets_manager_secret()

    cmd = f"psql -h localhost -p {local_port} -U {db_user} main"
    if connect:
        # Only set PGPASSWORD when a real password was fetched; otherwise psql
        # falls back to the caller's environment, ~/.pgpass, or a prompt.
        envs = {} if password is None else {"PGPASSWORD": password}
        code = run(cmd, envs)
        if code != 0:
            # A negative code means psql died from signal -code; report it
            # using the shell convention 128 + signal number.
            sys.exit(code if code > 0 else 128 - code)
    else:
        shown = "PASSWORD" if password is None else password
        # Quote so tokens/passwords with shell metacharacters paste safely.
        print(f"PGPASSWORD={shlex.quote(shown)} {cmd}")
@cache
def get_db_params():
    """Return every SSM parameter under /db as a {full name: value} dict.

    Keys are full parameter names such as "/db/ClusterEndpoint". The result
    is cached for the life of the process, so later callers reuse it instead
    of calling SSM again.
    """
    # GetParametersByPath returns at most 10 parameters per call; follow
    # NextToken via the paginator instead of reading only the first page.
    paginator = ssm_client.get_paginator("get_parameters_by_path")
    return {
        param["Name"]: param["Value"]
        for page in paginator.paginate(Path="/db", Recursive=True)
        for param in page["Parameters"]
    }
if __name__ == "__main__":
    # Script entry point: the click group parses sys.argv and dispatches to
    # the requested sub-command.
    cli()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment