Last active
October 11, 2022 19:23
-
-
Save gngdb/f968293a84765c0f9a4b8ae4c69551ba to your computer and use it in GitHub Desktop.
Script to do the same thing as https://cloud.google.com/sdk/gcloud/reference/compute/config-ssh but works for TPU VMs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Script to add gcloud TPU VM instances to ~/.ssh/config.
#
# Usage: tpu_configssh.py [--forward-agent] <instance_name> [<instance_name> ...]
#
# `gcloud compute config-ssh` does this for regular Compute Engine instances,
# but it doesn't work for TPU VMs.
#
# Works by parsing the ssh command printed by the --dry-run mode of
# `gcloud alpha compute tpus tpu-vm ssh`, for example:
# $ gcloud alpha compute tpus tpu-vm ssh instance-name --dry-run
# /usr/bin/ssh -t -i /home/user/.ssh/google_compute_engine -o CheckHostIP=no -o HashKnownHosts=no -o HostKeyAlias=<alias> -o IdentitiesOnly=yes -o StrictHostKeyChecking=no -o UserKnownHostsFile=/home/user/.ssh/google_compute_known_hosts user@IP
import argparse
import shlex
import shutil
import subprocess
import sys
from pathlib import Path
def pairwise(iterable):
    """Yield consecutive, non-overlapping pairs from *iterable*.

    s -> (s0, s1), (s2, s3), (s4, s5), ...

    A trailing element without a partner is consumed and discarded.
    """
    source = iter(iterable)
    # Passing the same iterator twice makes zip pull two items per tuple.
    yield from zip(source, source)
# OpenSSH ssh(1) short options that consume an argument
# (either attached, "-oFoo=bar", or as the next word, "-o Foo=bar").
_SSH_FLAGS_WITH_ARG = frozenset("BbcDEeFIiJLlmOoPpQRSWw")


def _split_ssh_args(args):
    """Split ssh arguments (program name excluded) into option pairs and destination.

    Returns ``(pairs, destination)``: ``pairs`` lists ``(flag, value)`` for
    every option that takes an argument, e.g. ``("-o", "CheckHostIP=no")``;
    ``destination`` is the first positional argument, or ``None`` if absent.
    Boolean flags such as ``-t`` are skipped.
    """
    pairs = []
    tokens = iter(args)
    for token in tokens:
        if token == "--":
            return pairs, next(tokens, None)
        if not token.startswith("-"):
            # First positional word is the destination; anything after it
            # would be the remote command, which the ssh config cannot hold.
            return pairs, token
        # getopt-style bundle: "-tt" is all boolean, "-ti path" ends at "i".
        for pos in range(1, len(token)):
            letter = token[pos]
            if letter in _SSH_FLAGS_WITH_ARG:
                value = token[pos + 1:] or next(tokens, None)
                if value is not None:
                    pairs.append(("-" + letter, value))
                break
    return pairs, None


def get_dryrun_info(instance_name):
    """Return the ssh arguments gcloud would use to connect to a TPU VM.

    Runs ``gcloud alpha compute tpus tpu-vm ssh <instance_name> --dry-run``
    and parses the ssh command line it prints.

    Returns:
        A list of ``(flag, value)`` tuples for each ssh option that takes an
        argument (e.g. ``("-i", "/home/user/.ssh/google_compute_engine")``,
        ``("-o", "HostKeyAlias=...")``), followed by
        ``("Destination", "user@IP")``.

    Raises:
        subprocess.CalledProcessError: if gcloud exits with an error.
        RuntimeError: if gcloud's output contains no ssh command line, or the
            ssh command has no destination.
    """
    cmd = ["gcloud", "alpha", "compute", "tpus", "tpu-vm", "ssh", instance_name, "--dry-run"]
    output = subprocess.check_output(cmd).decode("utf-8")
    for line in output.splitlines():
        words = line.split()
        # Check the program name before shlex-parsing: unrelated output lines
        # (warnings, notices) may contain unbalanced quotes.
        if not words or Path(words[0]).name != "ssh":
            continue
        pairs, destination = _split_ssh_args(shlex.split(line)[1:])
        if destination is None:
            raise RuntimeError("no destination in ssh command: {}".format(line))
        return pairs + [("Destination", destination)]
    raise RuntimeError("no ssh command found in output of: {}".format(" ".join(cmd)))
def parse_options(dryrun_info):
    """Convert ``(key, value)`` pairs from ``get_dryrun_info`` into ssh_config options.

    Args:
        dryrun_info: iterable of ``(key, value)`` tuples. Recognised keys:
            ``-o`` with ``"Name=Setting"`` or ``"Name Setting"``; ``-i`` with
            an identity-file path; ``Destination`` with ``"user@host"`` or a
            bare ``"host"``. Other keys are ignored.

    Returns:
        dict mapping ssh_config keywords (``IdentityFile``, ``User``,
        ``HostName`` and any ``-o`` names) to their values.
    """
    options = {}
    for key, value in dryrun_info:
        if key.startswith("-o"):
            # Split on the first "=" only: values such as ProxyCommand may
            # themselves contain "=".
            name, sep, setting = value.partition("=")
            if not sep:
                # ssh also accepts the whitespace form: -o "Name Setting".
                name, _, setting = value.strip().partition(" ")
            name = name.strip()
            if name:
                options[name] = setting.strip()
        elif key.startswith("-i"):
            options["IdentityFile"] = value
        elif key.startswith("Destination"):
            # rpartition: the host never contains "@", and the user part may
            # be missing entirely.
            user, _, host = value.rpartition("@")
            if user:
                options["User"] = user
            options["HostName"] = host
    return options
def parse_args():
    """Read the script's command line.

    Positional arguments are one or more TPU VM instance names; the optional
    ``--forward-agent`` switch is a boolean flag (default False).

    Returns:
        argparse.Namespace with ``instance_names`` (list of str) and
        ``forward_agent`` (bool).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--forward-agent", action="store_true")
    cli.add_argument("instance_names", nargs="+")
    namespace = cli.parse_args()
    return namespace
# Marker line written above every Host entry this script adds.
_MARKER = "## added by tpu_configssh.py ##"


def _drop_managed_entries(config_text, instance_names):
    """Return *config_text* without entries this script previously wrote for *instance_names*.

    An entry is the marker line, a ``Host <name>`` line, and the indented
    option lines that follow it (the format ``main`` writes). The blank
    separator line written before the marker is removed as well.
    """
    host_lines = {"Host {}".format(name) for name in instance_names}
    lines = config_text.splitlines(keepends=True)
    kept = []
    i = 0
    while i < len(lines):
        if (lines[i].rstrip("\r\n") == _MARKER
                and i + 1 < len(lines)
                and lines[i + 1].rstrip("\r\n") in host_lines):
            if kept and not kept[-1].strip():
                kept.pop()
            i += 2
            while i < len(lines) and lines[i][:1] in (" ", "\t"):
                i += 1
            continue
        kept.append(lines[i])
        i += 1
    return "".join(kept)


def main():
    """Add (or refresh) a Host entry in ~/.ssh/config for each requested TPU VM."""
    args = parse_args()
    ssh_dir = Path.home() / ".ssh"
    config_path = ssh_dir / "config"

    # Query gcloud for every instance before touching the config, so a
    # failure for one instance leaves the file unmodified.
    entries = []
    for instance_name in args.instance_names:
        print("Adding instance {} to ssh config".format(instance_name))
        options = parse_options(get_dryrun_info(instance_name))
        if args.forward_agent:
            options["ForwardAgent"] = "yes"
        print(options)
        entries.append((instance_name, options))

    ssh_dir.mkdir(mode=0o700, exist_ok=True)
    # surrogateescape round-trips any bytes, so non-UTF-8 content survives.
    encoding = {"encoding": "utf-8", "errors": "surrogateescape"}
    if config_path.exists():
        # Back up once, before any change: backing up per instance would
        # overwrite the backup with already-modified config.
        print("Copying ~/.ssh/config to ~/.ssh/config.bak")
        shutil.copyfile(config_path, ssh_dir / "config.bak")
        config_text = config_path.read_text(**encoding)
    else:
        config_text = ""

    # ssh uses the first obtained value for each option, so an old entry for
    # the same host (e.g. with a stale IP) would shadow the new one.
    config_text = _drop_managed_entries(config_text, args.instance_names)

    print("Adding instance to ~/.ssh/config")
    pieces = [config_text]
    for instance_name, options in entries:
        pieces.append("\n{}\n".format(_MARKER))
        pieces.append("Host {}\n".format(instance_name))
        pieces.extend(" {} {}\n".format(name, setting) for name, setting in options.items())
    # write_text follows symlinks, so a dotfile-managed config stays in place.
    config_path.write_text("".join(pieces), **encoding)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment