Last active
November 13, 2023 08:45
-
-
Save zhiweio/ba587ff05d5bcad7963521e6bd64232d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import os | |
import re | |
import socket | |
import stat | |
from datetime import datetime | |
import paramiko | |
import pytz | |
import redo | |
from paramiko.ssh_exception import AuthenticationException, SSHException | |
# Default per-request timeout, in seconds, applied to the SFTP channel when the
# caller does not supply a usable value.
REQUEST_TIMEOUT = 30

# Module-level logger shared by everything in this file.
LOG = logging.getLogger(__name__)
def handle_backoff():
    """Log that the SSH connection dropped and that a retry will follow.

    Passed as the ``cleanup`` callback of ``redo.retriable`` on the connect
    routine, so it runs after a failed attempt that is going to be retried.
    """
    # The original message read "Waiting seconds and retrying..." (the number
    # was missing); the wording no longer promises a specific delay.
    LOG.warning("SSH Connection closed unexpectedly. Waiting before retrying...")
class SFTPConnection:
    """SFTP client wrapper that connects lazily and retries on timeouts.

    The underlying paramiko transport is opened the first time ``sftp`` is
    accessed (or when the object is used as a context manager) and is closed
    by ``close()``, by leaving the ``with`` block, or on garbage collection.

    Args:
        host: Server host name or address.
        username: Login name.
        password: Optional password; also used as the fallback credential
            when key authentication fails.
        private_key_file: Optional path to an RSA private key (``~`` is
            expanded).
        port: Server port; ``None``, ``""``, ``0`` or ``"0"`` select 22.
        timeout: Per-request timeout in seconds; ``None``, ``""``, ``0`` or
            ``"0"`` select ``REQUEST_TIMEOUT``.
    """

    def __init__(
        self,
        host,
        username,
        password=None,
        private_key_file=None,
        port=None,
        timeout=REQUEST_TIMEOUT,
    ):
        # Connection state goes first so close()/__del__ stay safe even if a
        # later line of this constructor raises.
        self.__active_connection = False
        self.transport = None
        self.key = None

        self.host = host
        self.username = username
        self.password = password
        # int(None) raises TypeError, so normalise falsy values before
        # converting; a resulting 0 still falls back to the default port.
        self.port = int(port or 0) or 22

        if private_key_file:
            key_path = os.path.expanduser(private_key_file)
            self.key = paramiko.RSAKey.from_private_key_file(key_path)

        if timeout and float(timeout):
            self.request_timeout = float(timeout)
        else:
            # 0, "0", "" and None all fall back to the module default.
            self.request_timeout = REQUEST_TIMEOUT

    def __open_session(self, pkey):
        """Open a transport and SFTP client, authenticating with *pkey* and/or the password.

        The new transport is closed again if any step fails, so a failed
        attempt does not leave a half-open connection behind.
        """
        transport = paramiko.Transport((self.host, self.port))
        try:
            transport.use_compression(True)
            transport.connect(
                username=self.username,
                password=self.password,
                hostkey=None,
                pkey=pkey,
            )
            client = paramiko.SFTPClient.from_transport(transport)
        except Exception:
            transport.close()
            raise
        self.transport = transport
        self.sftp = client

    # If the connection is snapped during the connect flow (EOFError), retry:
    # up to 5 attempts, sleeping from 60 seconds up to at most 5 minutes.
    @redo.retriable(
        attempts=5,
        sleeptime=60,
        max_sleeptime=5 * 60,
        retry_exceptions=(EOFError,),
        cleanup=handle_backoff,
    )
    def __try_connect(self):
        """Connect if not already connected, then apply the request timeout.

        Key authentication (when a key was loaded) is tried first; on an
        authentication or SSH error a second attempt is made without the key.
        """
        if self.__active_connection:
            return
        try:
            self.__open_session(self.key)
        except (AuthenticationException, SSHException) as ex:
            LOG.warning("Connect to SFTP error %s, will retry...", ex)
            self.__open_session(None)
        self.__active_connection = True
        # The SFTP channel behaves like a socket; the timeout is set on it.
        channel = self.__sftp.get_channel()
        channel.settimeout(self.request_timeout)

    @property
    def sftp(self):
        """The paramiko SFTP client, connecting first if necessary."""
        self.__try_connect()
        return self.__sftp

    @sftp.setter
    def sftp(self, sftp):
        self.__sftp = sftp

    def __enter__(self):
        self.__try_connect()
        return self

    def __del__(self):
        """Clean up the socket when this object gets garbage collected."""
        self.close()

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        """Close the connection when leaving the ``with`` block.

        Returns None, so an exception raised inside the block propagates.
        """
        self.close()

    def close(self):
        """Close the SFTP client and the transport if a connection is active."""
        if not self.__active_connection:
            return
        try:
            self.__sftp.close()
        finally:
            # Always release the transport and reset the flag, even if closing
            # the SFTP client failed.
            self.transport.close()
            self.__active_connection = False

    @redo.retriable(
        attempts=6,
        sleeptime=10,
        max_sleeptime=60,
        retry_exceptions=(socket.timeout,),
    )
    def ch_mkdir(self, directory):
        """Change into *directory* on the server, creating missing levels.

        Returns True after changing into a non-root, non-empty directory and
        None for ``"/"`` and ``""``.
        """
        if directory == "/":
            # absolute path so change directory to root
            self.sftp.chdir("/")
            return
        if directory == "":
            # top-level relative directory must exist
            return
        try:
            self.sftp.chdir(directory)  # subdirectory exists
        except socket.timeout:
            # socket.timeout is an OSError; let it reach the retry decorator
            # instead of mistaking it for a missing directory.
            raise
        except (FileNotFoundError, IOError):
            dirname, basename = os.path.split(directory.rstrip("/"))
            self.ch_mkdir(dirname)  # make parent directories
            self.sftp.mkdir(basename)  # subdirectory missing, so create it
            self.sftp.chdir(basename)
        return True

    @redo.retriable(
        attempts=6,
        sleeptime=10,
        max_sleeptime=60,
        retry_exceptions=(socket.timeout,),
    )
    def put(self, localpath, remotepath, callback=None, confirm=True):
        """Upload *localpath* to *remotepath*; returns what ``SFTPClient.put`` returns."""
        return self.sftp.put(localpath, remotepath, callback=callback, confirm=confirm)

    @redo.retriable(
        attempts=6,
        sleeptime=10,
        max_sleeptime=60,
        retry_exceptions=(socket.timeout,),
    )
    def get_files_by_prefix(self, prefix):
        """List every non-empty file under the directory *prefix*, recursively.

        An empty or None prefix means the current remote directory (``"."``).

        Returns:
            A list of ``{"filepath": str, "last_modified": aware UTC datetime}``
            dicts, with paths built as ``prefix + "/" + filename``.

        Raises:
            Exception: if the directory does not exist.
        """
        if prefix is None or prefix == "":
            prefix = "."

        try:
            entries = self.sftp.listdir_attr(prefix)
        except FileNotFoundError as e:
            raise Exception("Directory '{}' does not exist".format(prefix)) from e

        files = []
        for entry in entries:
            # NB: SFTP specifies path characters to be '/'
            # https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-6
            filepath = prefix + "/" + entry.filename

            # st_mode can be None when the server omits permissions; such an
            # entry is treated as a regular file rather than crashing S_ISDIR.
            if entry.st_mode is not None and stat.S_ISDIR(entry.st_mode):
                files += self.get_files_by_prefix(filepath)
                continue

            if entry.st_size == 0:
                continue  # empty files are skipped

            last_modified = entry.st_mtime
            if last_modified is None:
                LOG.warning(
                    "Cannot read m_time for file %s, defaulting to current epoch time",
                    filepath,
                )
                # An aware "now": a naive utcnow().timestamp() would be read
                # as local time and give a shifted epoch on non-UTC hosts.
                last_modified = datetime.now(tz=pytz.UTC).timestamp()

            files.append(
                {
                    "filepath": filepath,
                    "last_modified": datetime.fromtimestamp(
                        last_modified, tz=pytz.UTC
                    ),
                }
            )
        return files

    def get_files(self, prefix, search_pattern, modified_since=None):
        """Return files under *prefix* matching *search_pattern*, oldest first.

        Args:
            prefix: Remote directory to scan recursively.
            search_pattern: Regular expression searched for in each file path.
            modified_since: Optional datetime; only files modified strictly
                after it are kept.
        """
        files = self.get_files_by_prefix(prefix)
        if files:
            LOG.info('Found %s files in "%s"', len(files), prefix)
        else:
            LOG.warning('Found no files on specified SFTP server at "%s"', prefix)

        matching_files = self.get_files_matching_pattern(files, search_pattern)
        if matching_files:
            LOG.info(
                'Found %s files in "%s" matching "%s"',
                len(matching_files),
                prefix,
                search_pattern,
            )
        else:
            LOG.warning(
                'Found no files on specified SFTP server at "%s" matching "%s"',
                prefix,
                search_pattern,
            )

        for f in matching_files:
            LOG.info("Found file: %s", f["filepath"])

        if modified_since is not None:
            matching_files = [
                f for f in matching_files if f["last_modified"] > modified_since
            ]

        # sort files in increasing order of "last_modified"
        return sorted(matching_files, key=lambda x: x["last_modified"].timestamp())

    # up to 6 attempts in total when the request times out
    @redo.retriable(
        attempts=6,
        sleeptime=60,
        max_sleeptime=300,
        retry_exceptions=(socket.timeout,),
    )
    def get_file_handle(self, f):
        """Open the remote file described by *f* for binary reading.

        Args:
            f: A file dict ``{"filepath": "...", "last_modified": ...}``.

        Returns:
            The handle returned by ``SFTPClient.open(filepath, "rb")``.

        Raises:
            OSError: re-raised after a warning is logged.
        """
        try:
            return self.sftp.open(f["filepath"], "rb")
        except OSError as e:
            if "Permission denied" in str(e):
                LOG.warning(
                    "Skipping %s file because you do not have enough permissions.",
                    f["filepath"],
                )
            else:
                LOG.warning(
                    "Skipping %s file because it is unable to be read.", f["filepath"]
                )
            raise

    def get_files_matching_pattern(self, files, pattern):
        """Return the file dicts in *files* whose "filepath" matches the regex *pattern*.

        Matching uses ``re.search``, so the pattern may match anywhere in the path.
        """
        matcher = re.compile(pattern)
        return [f for f in files if matcher.search(f["filepath"])]
def connection(config):
    """Build an ``SFTPConnection`` from a configuration mapping.

    ``host`` and ``username`` are required keys; ``password``,
    ``private_key_file``, ``port`` and ``request_timeout`` are optional and
    are passed as None when absent.
    """
    host = config["host"]
    username = config["username"]
    optional = {
        "password": config.get("password"),
        "private_key_file": config.get("private_key_file"),
        "port": config.get("port"),
        "timeout": config.get("request_timeout"),
    }
    return SFTPConnection(host, username, **optional)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment