Last active
August 13, 2024 14:41
-
-
Save niuniulla/fb16c4534d94b4a92226f50e43a39f6c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
################################################## | |
# File: 2024-08-12 | |
# To launch python script on remote server | |
# | |
# Usage: python remote_launch.py --log_file log_credentials.yaml --host '192.168.0.209' --script myscript.py | |
# The current script should be in the same directory as credentials.yaml, the script and bash_run.sh. | |
# arguments: | |
# - log_file: The path to the log file containing 'username' and 'password' to connect to remote point. | |
# This is a YAML file with the following structure: | |
# username: yourname | |
# password: the_password_here
# - host: The IP address or hostname of the remote server. | |
# - script: The python script to be executed on the remote server. | |
# This is needed to be copied to the remote server. | |
# - port: The port number for SSH connection. Default is 22. | |
# - project: The name of the project. This is used to create a directory with the | |
# same name on the remote server. By default, it is assigned a uuid. | |
# - env: The name of the Python virtual environment used on the remote server
# to run the script.
# if not provided, it defaults to the name project + "_env". | |
# - requirements: The file containing a list of python packages required to run the script. | |
# The default is requirements.txt. | |
# | |
# On the remote server, each user will have a directory with the same name as username created under /tmp. | |
# All related files, directories and output files will be created and managed under this directory. | |
# For a user with name 'john', the directory will be /tmp/john and by running the script with a project name "project1", | |
# The directory structure under this user is: | |
# | |
# /tmp/john | |
# ├── project1: the project directory of name 'project1' | |
# | | |
# ├── data (s): a directory containing all data files created locally and mounted to the remote server. | |
# ├── outputs (s): a directory created remotely and mounted locally. it contains all output files. | |
# ├── bash_run.sh (s): bash file to run the script. | |
# ├── myscript.py (s): python file to be executed on the remote server. | |
# ├── requirements.txt (s): a file containing python packages required to run the script. | |
# ├── project1_env (s): the python environment to run the script.
# ├── log_credentials.yaml (s): the log file containing 'username' and 'password' to connect to remote server. | |
# └── train_launcher_log.log: log file for the launcher script automatically created by logger. It contains all information about the script execution and script outputs
# ├── project2: the project directory of name 'project2' | |
# ... | |
# | |
# The dirs and files marked as (s) are synchronized between local and remote environments. | |
# Otherwise, they only exist in the local machine. | |
# | |
# The dirs 'data' and 'outputs' are created by default: | |
# - data is usually used to store data files. It is a local directory that is mounted to
# the remote server. So, if the python program requires extensive data reading, this may | |
# not be a good strategy, but this can prevent unnecessary disk space usage and data loss. | |
# - outputs is used to store all output files from script. It is a remote directory that is | |
# mounted to the local machine. So if the connection is lost, user can't access the output | |
# files anymore, to avoid this, user can copy the files manually. | |
################################################## | |
import argparse
import getpass
import logging
import os
import posixpath
import shlex
import socket
import subprocess
import sys
import time
import uuid
from sys import platform

import paramiko
import yaml
def read_log_file(_file_path):
    """Read the remote login credentials from a YAML file.

    The file is expected to hold a mapping with the keys ``username`` and
    ``password``.

    Args:
        _file_path: Path to the YAML credentials file.

    Returns:
        A ``(username, password)`` tuple, or ``None`` when the file is
        missing, cannot be read or parsed, is not a mapping, or has a
        missing/empty ``username`` or ``password``. Every failure is logged.
    """
    if not os.path.isfile(_file_path):
        logging.error(f"Log file '{_file_path}' does not exist.")
        return None

    try:
        # The context manager closes the handle even when parsing fails.
        with open(_file_path, 'r', encoding="utf-8") as stream:
            items = yaml.safe_load(stream)
    except (OSError, yaml.YAMLError) as error:
        logging.error(f"Could not read log file '{_file_path}': {error}")
        return None

    # An empty document loads as None and a bare scalar/list is not a
    # mapping; neither can carry the expected keys.
    if not isinstance(items, dict):
        logging.error("Log file does not contain 'username' or 'password' key.")
        return None

    username = items.get('username', "")
    password = items.get('password', "")
    # A key with an empty value ("username:") loads as None, so treat None
    # the same as a missing key.
    if username in (None, "") or password in (None, ""):
        logging.error("Log file does not contain 'username' or 'password' key.")
        return None
    return (username, password)
def mkdir_p(sftp, remote_directory):
    """Create ``remote_directory`` and any missing parents over SFTP.

    Works like ``mkdir -p`` and leaves the SFTP session's current directory
    set to ``remote_directory`` (nothing is changed for an empty path).

    Args:
        sftp: An object providing ``chdir(path)`` and ``mkdir(path)``, where
            ``chdir`` raises ``IOError`` for a missing directory.
        remote_directory: Remote path, absolute or relative to the session's
            current directory.

    Returns:
        ``True`` if the final directory had to be created, ``False`` if it
        already existed (or the path was ``'/'`` or empty).
    """
    if remote_directory == '/':
        # absolute path so change directory to root
        sftp.chdir('/')
        return False
    if remote_directory == '':
        # top-level relative directory must exist
        return False
    try:
        sftp.chdir(remote_directory)  # sub-directory exists
    except IOError:
        # Remote SFTP paths always use '/', so split with posixpath rather
        # than the local platform's os.path.
        dirname, basename = posixpath.split(remote_directory.rstrip('/'))
        mkdir_p(sftp, dirname)  # make parent directories (and chdir there)
        sftp.mkdir(basename)  # sub-directory missing, so create it
        sftp.chdir(basename)
        return True
    return False
def get_args(argv=None):
    """Parse the launcher's command-line arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``log_file``, ``host``,
        ``port``, ``project``, ``script``, ``env``, ``requirements``,
        ``data_dir`` and ``output_dir``.
    """
    parser = argparse.ArgumentParser(description='Training on server.')
    parser.add_argument('--log_file',
                        help='Log file path.')
    parser.add_argument('--host',
                        help='Host name or IP.')
    # const=22 keeps a bare "--port" (no value) on the default port instead
    # of producing None.
    parser.add_argument('--port', nargs="?", type=int, default=22, const=22,
                        help='Port number for SSH connection.')
    parser.add_argument('--project',
                        help='Project name to identify the task to be launched.')
    parser.add_argument('--script',
                        help='The file to execute on server.')
    parser.add_argument('--env',
                        help='environment name to be used to run the script.')
    parser.add_argument('--requirements', default="requirements.txt",
                        help='The file containing the python packages used by script.')
    parser.add_argument('--data_dir', default="data", help='Local data directory.')
    parser.add_argument('--output_dir', default="outputs", help='Local output directory.')
    return parser.parse_args(argv)
def _send_shell_cmd(session, cmd, show=True, timeout=600.0, poll_interval=0.1):
    """Run one command in the remote shell and wait for it to finish.

    A marker line carrying the command's exit status is echoed after the
    command, and output is read until that marker arrives. This replaces a
    fixed short wait, which cut off slow commands.

    Args:
        session: A paramiko channel on which ``invoke_shell()`` was called.
        cmd: The shell command line to run.
        show: When ``False`` the command text is kept out of the log (use
            this for commands that embed a password).
        timeout: Maximum number of seconds to wait for the command.
        poll_interval: Sleep between polls while no data is available.

    Returns:
        ``(exit_status, output)``: the integer exit status (``-1`` if it
        could not be parsed) and the command's stripped standard output.

    Raises:
        TimeoutError: If the marker does not arrive within ``timeout``.
        paramiko.SSHException: If the remote shell closes before the marker.
    """
    if show:
        logging.info(f"Executing command: {cmd}")
    else:
        logging.info("Executing command: <hidden, contains credentials>")

    marker = "__REMOTE_CMD_DONE__"
    half = len(marker) // 2
    # The marker is written as two adjacent quoted halves, so the command
    # text itself never contains the full marker even if the shell echoes
    # its input.
    payload = f'{cmd.strip()}\necho "{marker[:half]}""{marker[half:]}:$?"\n'
    session.sendall(payload.encode("utf-8"))

    token = (marker + ":").encode("ascii")
    deadline = time.monotonic() + timeout
    received = b""
    errors = b""
    while True:
        progressed = False
        # Drain stderr as well so a chatty command cannot stall on a full
        # stderr window.
        if session.recv_stderr_ready():
            errors += session.recv_stderr(4096)
            progressed = True
        if session.recv_ready():
            received += session.recv(4096)
            progressed = True

        start = received.find(token)
        stop = received.find(b"\n", start) if start != -1 else -1
        if stop != -1:
            break
        if progressed:
            continue
        if session.closed or session.exit_status_ready():
            raise paramiko.SSHException(
                "Remote shell closed before the command completed.")
        if time.monotonic() > deadline:
            raise TimeoutError(
                f"Remote command did not finish within {timeout} seconds.")
        time.sleep(poll_interval)

    status_text = received[start + len(token):stop].strip()
    status = int(status_text) if status_text.isdigit() else -1
    # Decode once at the end so multi-byte characters split across reads
    # are not corrupted.
    output = received[:start].decode("utf-8", errors="replace").strip()
    if errors:
        logging.warning(
            f"Remote stderr: {errors.decode('utf-8', errors='replace').strip()}")
    return status, output


def _send_file(sftp, local_path, remote_dir):
    """Upload ``local_path`` into ``remote_dir`` keeping its permission bits.

    The remote file is named after the local file's basename, so a local
    path containing directories still lands directly in ``remote_dir``.
    """
    remote_path = posixpath.join(remote_dir, os.path.basename(local_path))
    sftp.put(local_path, remote_path)
    sftp.chmod(remote_path, os.stat(local_path).st_mode & 0o777)


def main():
    """
    Set up the remote environment and launch the script on the remote server.

    Steps: validate local inputs, connect over SSH with username/password,
    create the remote working directory ``/tmp/<username>/<project>``, mount
    the local data directory on the remote side and the remote ``outputs``
    directory locally (both through sshfs), create the virtual environment if
    it is missing, upload the script, the requirements file and
    ``bash_run.sh``, then run ``bash_run.sh`` and copy its output to the log.

    Exits with status 0 when the remote run succeeds and 1 otherwise.
    """
    # setup logging to file
    logging.basicConfig(filename='train_launcher_log.log',
                        encoding="utf-8",
                        filemode="w",
                        format="{asctime} - {module} - {levelname} - {message}",
                        style="{",
                        datefmt="%Y-%m-%d %H:%M",
                        level=logging.INFO
                        )

    # argument parsing for command-line arguments
    args = get_args()

    # ---- local validation: cheap checks first, before any prompt or network I/O
    if platform.startswith("linux"):
        logging.info(f"Local - System platform: {platform}")
    else:
        logging.error("Local - Systems other than Linux are not supported.")
        sys.exit(1)

    if args.host is None:
        logging.error("No host provided.")
        sys.exit(1)
    logging.info(f"Will connect to host: {args.host}.")

    # get credentials from log file if provided, else from the environment
    if args.log_file:
        logging.info(f"Get remote credentials from log file: {args.log_file}.")
        credentials = read_log_file(args.log_file)
        if credentials is None:
            # read_log_file already logged the reason; connecting with empty
            # credentials could only fail later with a less helpful error.
            logging.error(f"Could not get remote credentials from '{args.log_file}'.")
            sys.exit(1)
        username, password = credentials
        logging.info(f"Will use remote credentials for user: {username}.")
    else:
        logging.warning("No log file provided. Will use default log crediential.")
        username, password = os.getenv('SSH_USERNAME'), os.getenv('SSH_PASSWORD')
        if username is None or password is None:
            logging.error("No 'SSH_USERNAME' or 'SSH_PASSWORD' environment variables found.")
            sys.exit(1)
        logging.info(f"Will use remote credentials for user: {username} for longin.")

    # set working dir
    project = args.project if args.project is not None else uuid.uuid4().hex[:8]
    working_dir = "/tmp/" + str(username) + "/" + project
    logging.info(f"The remote working directory for the project '{project}' is: {working_dir}.")

    # check for environment name
    env_name = args.env if args.env else project + "_env"

    # check for script (args.script is None when --script was not given)
    if not args.script or not os.path.isfile(args.script):
        logging.error(f"Script '{args.script}' does not exist.")
        sys.exit(1)

    # check for bash_run
    if not os.path.isfile("bash_run.sh"):
        logging.error("No bash_run provided.")
        sys.exit(1)

    # check the file for package install
    requirements_ok = bool(args.requirements) and os.path.isfile(args.requirements)
    if requirements_ok:
        logging.info(f"The requirements file is: {args.requirements}.")
    else:
        logging.warning(f"The requirements file '{args.requirements}' does not exist. "
                        "The script may not be executed correctly.")

    # ---- local context
    current_dir = os.getcwd()
    logging.info(f"Local - Current directory: {current_dir}.")
    osuser = getpass.getuser()
    # NOTE(review): gethostbyname(gethostname()) can resolve to a loopback
    # address on some hosts, which the remote side cannot mount from --
    # confirm this returns a routable address in the target environment.
    my_ip = socket.gethostbyname(socket.gethostname())
    logging.info(f"Local - Username: {osuser}, Local IP: {my_ip}.")

    local_data_dir = os.path.abspath(args.data_dir)
    local_output_dir = args.output_dir
    mount_data = os.path.isdir(local_data_dir)
    # The local password is only needed for the remote host to mount the
    # local data directory, so only ask when there is something to mount.
    ospwd = None
    if mount_data:
        ospwd = getpass.getpass(prompt="Enter password for {}: ".format(osuser))

    remote_data_dir = posixpath.join(working_dir, "data")
    remote_output_dir = posixpath.join(working_dir, "outputs")
    q_working_dir = shlex.quote(working_dir)

    # ---- tasks on remote host
    exit_code = 1
    client = paramiko.SSHClient()
    try:
        # Connect to remote host
        logging.info(f"Connecting to host: {username}@{args.host}:{args.port}.")
        client.load_system_host_keys()
        # NOTE(review): AutoAddPolicy accepts unknown host keys without
        # verification; consider RejectPolicy on untrusted networks.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        client.connect(args.host,
                       port=args.port,
                       username=username,
                       password=password,
                       allow_agent=False,  # Do not use SSH agent
                       look_for_keys=False  # Do not use ssh keys
                       )

        # open an interactive shell so state such as "cd" persists
        # between the commands below
        session = client.get_transport().open_session()
        session.invoke_shell()

        def send_cmd(cmd, show=True):
            # thin wrapper binding the shell session
            return _send_shell_cmd(session, cmd, show=show)

        try:
            # setup working directory
            send_cmd(f"mkdir -p {q_working_dir}")
            status, _ = send_cmd(f"cd {q_working_dir}")
            if status != 0:
                logging.error(f"Failed to setup remote working directory for the project '{project}'.")
                sys.exit(1)
            _, out = send_cmd("pwd")
            logging.info(f"Remote working directory for the project '{project}' is: {out}")

            # set data directory: mount local data into the remote working dir
            if mount_data:
                logging.info(f"Mount local data '{local_data_dir}' "
                             f"to remote working directory '{remote_data_dir}'.")
                q_remote_data = shlex.quote(remote_data_dir)
                send_cmd(f"mkdir -p {q_remote_data}")
                status, _ = send_cmd(f"mountpoint -q {q_remote_data}")
                if status == 0:
                    # already mounted: unmount first
                    logging.info(f"Umounting remote data directory '{remote_data_dir}'.")
                    send_cmd(f"fusermount -u {q_remote_data} || umount {q_remote_data}")
                source = shlex.quote(f"{osuser}@{my_ip}:{local_data_dir}")
                # printf is a shell builtin, so the password is piped to
                # sshfs without appearing in the remote process list;
                # show=False keeps it out of the local log.
                cmd = (f"printf '%s\\n' {shlex.quote(ospwd)} | "
                       f"sshfs -o password_stdin {source} {q_remote_data}")
                status, _ = send_cmd(cmd, show=False)
                if status != 0:
                    logging.error(f"Failed to mount local data directory on '{remote_data_dir}'.")
            else:
                logging.warning(f"Local data directory '{args.data_dir}' does not exist.")

            # set output directory (mkdir -p is a no-op when it exists)
            logging.info(f"Creating remote output directory: {remote_output_dir}.")
            send_cmd(f"mkdir -p {shlex.quote(remote_output_dir)}")

            # mount remote output directory to local output dir
            if not os.path.exists(local_output_dir):
                logging.info(f"Creating local mount point: '{local_output_dir}'.")
                os.makedirs(local_output_dir, exist_ok=True)
            if not os.path.ismount(local_output_dir):
                # Argument list + stdin: no shell is involved, so there is no
                # bash-only here-string and no password on the command line.
                result = subprocess.run(
                    ["sshfs", "-p", str(args.port), "-o", "password_stdin",
                     f"{username}@{args.host}:{remote_output_dir}",
                     local_output_dir],
                    input=f"{password}\n",
                    text=True,
                )
                if result.returncode != 0:
                    logging.error(f"Failed to mount remote output directory on '{local_output_dir}'.")
            else:
                logging.info(f"The local directory '{local_output_dir}' is already mounted.")

            # check if env exists, if not create it
            q_env = shlex.quote(env_name)
            status, _ = send_cmd(f"test -d {q_env}")
            if status == 0:
                logging.info(f"The environment: {env_name} exists.")
            else:
                logging.info(f"Creating environment: {env_name}")
                status, out = send_cmd(f"python -m venv {q_env}")
                if status != 0:
                    logging.error(f"Failed to create environment '{env_name}': {out}")
        finally:
            session.close()

        # Setup sftp connection and transmit script and requirements
        logging.info(f"Deploying script '{args.script}' to host: {args.host}.")
        uploads = [args.script, "bash_run.sh"]
        if requirements_ok:
            uploads.append(args.requirements)
        sftp = client.open_sftp()
        try:
            for local_file in uploads:
                _send_file(sftp, local_file, working_dir)
        finally:
            sftp.close()

        # run script
        steps = []
        if requirements_ok:
            remote_requirements = "./" + os.path.basename(args.requirements)
            steps.append(f"pip install -r {shlex.quote(remote_requirements)}")  # install requirements
        steps.append("./bash_run.sh")  # run script
        activate = shlex.quote(f"./{env_name}/bin/activate")
        # Stop early if the directory or the environment is unusable; "." is
        # the portable spelling of "source".
        command = f"cd {q_working_dir} && . {activate} && {{ {'; '.join(steps)}; }}"
        logging.info(f"Running: {command}")
        _, _stdout, _ = client.exec_command(command, get_pty=True)
        _stdout.channel.set_combine_stderr(True)
        for line in _stdout:
            logging.info(f"{line.strip()}")
        status = _stdout.channel.recv_exit_status()
        if status == 0:
            logging.info("Script execution completed.")
            exit_code = 0
        else:
            logging.error(f"Remote script exited with status {status}.")
    except paramiko.AuthenticationException as error:
        logging.error(f"Authentication failed: {error}")
    except paramiko.SSHException as error:
        logging.error(f"SSH error: {error}")
    except FileNotFoundError as error:
        logging.error(f"File not found on local machine: {error}")
    except OSError as error:
        # connection refused / unreachable host / command timeout
        logging.error(f"Connection or I/O error: {error}")
    finally:
        # clean up, also on failure and on sys.exit() raised above
        client.close()
    sys.exit(exit_code)
# Run the launcher only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment