Skip to content

Instantly share code, notes, and snippets.

@zhiweio
Last active November 13, 2023 08:45
Show Gist options
  • Save zhiweio/ba587ff05d5bcad7963521e6bd64232d to your computer and use it in GitHub Desktop.
Save zhiweio/ba587ff05d5bcad7963521e6bd64232d to your computer and use it in GitHub Desktop.
import logging
import os
import re
import socket
import stat
from datetime import datetime
import paramiko
import pytz
import redo
from paramiko.ssh_exception import AuthenticationException, SSHException
LOG = logging.getLogger(__name__)

# Default per-request timeout in seconds; used whenever the configured
# timeout is missing or falsy (None, 0, "0", "").
REQUEST_TIMEOUT = 30
def handle_backoff():
    """Log a warning before the SSH connect is retried.

    Passed as the ``cleanup`` hook of the connection retry decorator. It takes
    no arguments, so the upcoming sleep duration is not available here and is
    not included in the message.
    """
    LOG.warning("SSH Connection closed unexpectedly. Retrying...")
class SFTPConnection:
    """Wrapper around a paramiko SFTP session with lazy connect and retries.

    The network connection is opened on first access to :attr:`sftp` (or on
    entering a ``with`` block) and closed by :meth:`close`, by leaving the
    ``with`` block, or when the object is garbage collected.

    :param host: SFTP server host name or address.
    :param username: login user name.
    :param password: optional password; also used for the password-only
        fallback when authenticating with the key fails.
    :param private_key_file: optional path to an RSA private key (``~`` is
        expanded).
    :param port: server port; ``None`` or any value whose ``int`` is 0 means 22.
    :param timeout: channel timeout in seconds; falsy values (``None``, ``0``,
        ``"0"``, ``""``) fall back to ``REQUEST_TIMEOUT``.
    """

    def __init__(
        self,
        host,
        username,
        password=None,
        private_key_file=None,
        port=None,
        timeout=REQUEST_TIMEOUT,
    ):
        # Initialise the state close()/__del__ rely on first, so cleanup is
        # safe even if a later line of __init__ raises.
        self.__active_connection = False
        self.__sftp = None
        self.transport = None
        self.host = host
        self.username = username
        self.password = password
        # int(None) raises TypeError, so coerce missing/empty ports to 0 first;
        # a port of 0 then falls back to 22, as before.
        self.port = int(port or 0) or 22
        self.key = None
        if private_key_file:
            key_path = os.path.expanduser(private_key_file)
            self.key = paramiko.RSAKey.from_private_key_file(key_path)
        if timeout and float(timeout):
            self.request_timeout = float(timeout)
        else:
            # None, 0, "0" or "" -> module default (REQUEST_TIMEOUT seconds).
            self.request_timeout = REQUEST_TIMEOUT

    def _open_session(self, pkey):
        """(Re)create the transport, authenticate, and open an SFTP client."""
        # Close any transport left over from a previous (failed) attempt so
        # retries do not leak sockets.
        if self.transport is not None:
            self.transport.close()
        self.transport = paramiko.Transport((self.host, self.port))
        self.transport.use_compression(True)
        # NOTE(review): hostkey=None skips server host-key verification, which
        # leaves the connection open to man-in-the-middle attacks.
        self.transport.connect(
            username=self.username,
            password=self.password,
            hostkey=None,
            pkey=pkey,
        )
        self.__sftp = paramiko.SFTPClient.from_transport(self.transport)

    # If the connection is snapped during the connect flow (EOFError), retry:
    # up to 5 attempts, starting at 60s between attempts and backing off to at
    # most 5 minutes (see redo.retry for the exact schedule).
    @redo.retriable(
        attempts=5,
        sleeptime=60,
        max_sleeptime=5 * 60,
        retry_exceptions=(EOFError,),
        cleanup=handle_backoff,
    )
    def __try_connect(self):
        """Connect unless a connection is already active."""
        if self.__active_connection:
            return
        try:
            self._open_session(pkey=self.key)
        except (AuthenticationException, SSHException) as ex:
            LOG.warning("Connect to SFTP error %s, will retry...", ex)
            # Fall back to password-only authentication.
            self._open_session(pkey=None)
        self.__active_connection = True
        # Apply the request timeout to the SFTP channel.
        channel = self.__sftp.get_channel()
        channel.settimeout(self.request_timeout)

    @property
    def sftp(self):
        """The ``paramiko.SFTPClient``, connecting first if necessary."""
        self.__try_connect()
        return self.__sftp

    @sftp.setter
    def sftp(self, sftp):
        self.__sftp = sftp

    def __enter__(self):
        self.__try_connect()
        return self

    def __exit__(self, exc_type=None, exc_value=None, exc_tb=None):
        """Close the connection on leaving a ``with`` block.

        Returns ``None``, so exceptions raised inside the block propagate.
        """
        self.close()

    def __del__(self):
        """Clean up the socket when this object gets garbage collected."""
        self.close()

    def close(self):
        """Close the SFTP session and transport if a connection is active."""
        # getattr: __del__ may run on an object whose __init__ never ran.
        if not getattr(self, "_SFTPConnection__active_connection", False):
            return
        self.__active_connection = False
        try:
            self.__sftp.close()
        finally:
            # Always release the transport, even if closing the session failed.
            self.transport.close()

    @redo.retriable(
        attempts=6,
        sleeptime=10,
        max_sleeptime=60,
        retry_exceptions=(socket.timeout,),
    )
    def ch_mkdir(self, directory):
        """Change into ``directory`` on the server, creating missing parents.

        Works like ``mkdir -p`` followed by ``cd``. Returns ``None`` for the
        ``"/"`` and ``""`` shortcuts and ``True`` otherwise.
        """
        if directory == "/":
            # Absolute root: just change directory to it.
            self.sftp.chdir("/")
            return
        if directory == "":
            # Top-level relative directory must already exist.
            return
        try:
            self.sftp.chdir(directory)  # subdirectory exists
        except socket.timeout:
            # socket.timeout is an OSError subclass; let it reach the retry
            # decorator instead of being mistaken for a missing directory.
            raise
        except (FileNotFoundError, IOError):
            dirname, basename = os.path.split(directory.rstrip("/"))
            self.ch_mkdir(dirname)  # make parent directories
            self.sftp.mkdir(basename)  # subdirectory missing, so create it
            self.sftp.chdir(basename)
        return True

    @redo.retriable(
        attempts=6,
        sleeptime=10,
        max_sleeptime=60,
        retry_exceptions=(socket.timeout,),
    )
    def put(self, localpath, remotepath, callback=None, confirm=True):
        """Upload ``localpath`` to ``remotepath`` via ``SFTPClient.put``."""
        return self.sftp.put(localpath, remotepath, callback=callback, confirm=confirm)

    @redo.retriable(
        attempts=6,
        sleeptime=10,
        max_sleeptime=60,
        retry_exceptions=(socket.timeout,),
    )
    def get_files_by_prefix(self, prefix):
        """Recursively list the non-empty, non-directory entries under ``prefix``.

        ``None`` or ``""`` is treated as ``"."``.

        :returns: list of ``{"filepath": str, "last_modified": datetime}``
            dicts. ``filepath`` is ``prefix`` joined to the entry name with
            ``/``. ``last_modified`` is a UTC-aware datetime built from the
            entry's mtime, or the current time if the server reports none.
        :raises Exception: if ``prefix`` does not exist on the server.
        """
        files = []
        if prefix is None or prefix == "":
            prefix = "."
        try:
            result = self.sftp.listdir_attr(prefix)
        except FileNotFoundError as e:
            raise Exception("Directory '{}' does not exist".format(prefix)) from e
        for file_attr in result:
            # NB: SFTP specifies path separators to be '/'
            # https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-6
            path = prefix + "/" + file_attr.filename
            if stat.S_ISDIR(file_attr.st_mode):
                # Descend into subdirectories (the listing is recursive).
                files += self.get_files_by_prefix(path)
                continue
            if file_attr.st_size == 0:
                continue
            last_modified = file_attr.st_mtime
            if last_modified is None:
                LOG.warning(
                    "Cannot read m_time for file %s, defaulting to current epoch time",
                    os.path.join(prefix, file_attr.filename),
                )
                # datetime.utcnow().timestamp() would interpret naive UTC wall
                # time as local time; an aware "now" gives the true epoch.
                last_modified = datetime.now(pytz.UTC).timestamp()
            files.append(
                {
                    "filepath": path,
                    "last_modified": datetime.fromtimestamp(last_modified, tz=pytz.UTC),
                }
            )
        return files

    def get_files(self, prefix, search_pattern, modified_since=None):
        """Return files under ``prefix`` matching ``search_pattern``, oldest first.

        :param prefix: remote directory to search recursively.
        :param search_pattern: regex tested with ``re.search`` against each
            file path.
        :param modified_since: optional datetime; only files modified strictly
            after it are kept. It is compared with UTC-aware datetimes, so it
            should be timezone-aware.
        :returns: matching file dicts sorted by increasing ``last_modified``.
        """
        files = self.get_files_by_prefix(prefix)
        if files:
            LOG.info('Found %s files in "%s"', len(files), prefix)
        else:
            LOG.warning('Found no files on specified SFTP server at "%s"', prefix)
        matching_files = self.get_files_matching_pattern(files, search_pattern)
        if matching_files:
            LOG.info(
                'Found %s files in "%s" matching "%s"',
                len(matching_files),
                prefix,
                search_pattern,
            )
        else:
            LOG.warning(
                'Found no files on specified SFTP server at "%s" matching "%s"',
                prefix,
                search_pattern,
            )
        for f in matching_files:
            LOG.info("Found file: %s", f["filepath"])
        if modified_since is not None:
            matching_files = [
                f for f in matching_files if f["last_modified"] > modified_since
            ]
        # Aware datetimes order the same way as their timestamps.
        return sorted(matching_files, key=lambda f: f["last_modified"])

    # Retry up to 6 attempts in total on socket timeouts.
    @redo.retriable(
        attempts=6,
        sleeptime=60,
        max_sleeptime=300,
        retry_exceptions=(socket.timeout,),
    )
    def get_file_handle(self, f):
        """Open a remote file for binary reading.

        :param f: file dict ``{"filepath": ..., "last_modified": ...}``.
        :returns: the handle from ``SFTPClient.open(filepath, "rb")``.
        :raises OSError: re-raised after logging why the file is skipped.
            ``socket.timeout`` is re-raised without that log so it can be
            retried.
        """
        try:
            return self.sftp.open(f["filepath"], "rb")
        except socket.timeout:
            raise
        except OSError as e:
            if "Permission denied" in str(e):
                LOG.warning(
                    "Skipping %s file because you do not have enough permissions.",
                    f["filepath"],
                )
            else:
                LOG.warning(
                    "Skipping %s file because it is unable to be read.", f["filepath"]
                )
            raise

    def get_files_matching_pattern(self, files, pattern):
        """Return the file dicts whose ``"filepath"`` matches regex ``pattern``.

        Matching uses ``re.search``, so the pattern may match anywhere in the
        path. Input order is preserved.
        """
        matcher = re.compile(pattern)
        return [f for f in files if matcher.search(f["filepath"])]
def connection(config):
    """Build an ``SFTPConnection`` from a config mapping.

    Required keys: ``host`` and ``username``. Optional keys: ``password``,
    ``private_key_file``, ``port`` (22 when missing or falsy) and
    ``request_timeout``. No network connection is made here.
    """
    return SFTPConnection(
        config["host"],
        config["username"],
        password=config.get("password"),
        private_key_file=config.get("private_key_file"),
        # Default here so a missing port never reaches int(None).
        port=config.get("port") or 22,
        timeout=config.get("request_timeout"),
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment