"""Provides methods around syncing and usage of AWS s3 buckets as local caches rather than individually | |
downloading every file""" | |
import os | |
import shutil | |
import boto3 as boto | |
import multiprocessing | |
import copy | |
import hashlib | |
import logging | |
log = logging.getLogger(__name__) | |
class LocalObjectCache:
    """Provides a local cache of an S3 bucket on disk, with the ability to sync up to the latest version of all files"""

    _DEFAULT_PATH = '/tmp/local_object_store/'

    def __init__(self, bucket_name, prefix='', path=None):
        """Init method

        :param bucket_name: str, the name of the S3 bucket
        :param prefix: str, the key prefix under which objects are synced
        :param path: (optional, str) a path at which to store the local files
        """
        self.bucket_name = bucket_name
        self.prefix = prefix
        if not path:
            path = self._DEFAULT_PATH + self.bucket_name + '/'
        self.path = path
        os.makedirs(path, exist_ok=True)
        s3 = boto.resource('s3')
        self.bucket = s3.Bucket(self.bucket_name)
    def __enter__(self):
        """Provides a context manager which opens (but does not sync) the cache"""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Deletes the local cache on exit"""
        self.close()
    def __getstate__(self):
        # Required to pickle and un-pickle the object between multiprocessing workers:
        # boto3 resources cannot be pickled, so drop the bucket before pickling
        out = copy.copy(self.__dict__)
        out['bucket'] = None
        return out

    def __setstate__(self, d):
        # Re-create the bucket resource in the worker process after un-pickling
        s3 = boto.resource('s3')
        d['bucket'] = s3.Bucket(d['bucket_name'])
        self.__dict__ = d
    def get_path(self, key):
        """Returns the local storage path for a file key (given relative to the configured prefix)"""
        return os.path.join(self.path, self.prefix, key)
    @staticmethod
    def calculate_s3_etag(file, chunk_size=8 * 1024 * 1024):
        """Calculates the S3 custom e-tag (a specially formatted MD5 hash), assuming the
        AWS CLI's default 8 MB multipart chunk size"""
        md5s = []
        while True:
            data = file.read(chunk_size)
            if not data:
                break
            md5s.append(hashlib.md5(data))
        # Single-part upload: the etag is just the file's MD5
        if len(md5s) == 1:
            return '"{}"'.format(md5s[0].hexdigest())
        # Multipart upload: the etag is the MD5 of the concatenated part MD5s,
        # suffixed with the part count
        digests = b''.join(m.digest() for m in md5s)
        digests_md5 = hashlib.md5(digests)
        return '"{}-{}"'.format(digests_md5.hexdigest(), len(md5s))
    def _get_obj(self, key, tag):
        """Downloads the object at key, skipping the download if an existing local file already matches the S3 etag"""
        path = os.path.join(self.path, key)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        dl_flag = True
        try:
            with open(path, 'rb') as f:
                if tag == self.calculate_s3_etag(f):
                    log.info('Cache Hit')
                    dl_flag = False
        except FileNotFoundError:
            pass
        if dl_flag:
            log.info('Cache Miss')
            self.bucket.download_file(key, path)
    def sync(self):
        """Syncs the local copy with the remote S3 bucket, downloading objects in parallel"""
        keys = [(obj.key, obj.e_tag) for obj in self.bucket.objects.filter(Prefix=self.prefix)]
        with multiprocessing.Pool() as pool:
            pool.starmap(self._get_obj, keys)
    def close(self):
        """Deletes all local files"""
        shutil.rmtree(self.path)
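
A minimal usage sketch (the bucket name, prefix, and key below are hypothetical placeholders, not part of the gist): sync a prefix into the local cache inside a with block, then read the files from disk. On exit the cache directory is deleted.

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # 'my-data-bucket', 'models/' and 'weights.bin' are made-up examples
    with LocalObjectCache('my-data-bucket', prefix='models/') as cache:
        cache.sync()  # parallel download of anything missing or stale
        local_file = cache.get_path('weights.bin')  # key given relative to the prefix
        with open(local_file, 'rb') as f:
            payload = f.read()
    # the local cache directory has now been removed

The __getstate__/__setstate__ pair is what lets sync() hand self._get_obj to a multiprocessing pool: each worker re-creates its own boto3 bucket resource instead of trying to pickle the parent's.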
Thanks for sharing, this really got me going. One issue I've found: calculate_s3_etag does not match S3 when the file size is zero. With the one file I've seen so far, simply removing the trailing -0 from the generated etag makes it match S3.
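
For anyone hitting the same edge case: for a zero-byte object (uploaded without multipart), S3 stores the plain MD5 of empty content with no part-count suffix, so one way to patch the function is an early return before the single/multipart branches. A sketch, not tested against every upload path:

    @staticmethod
    def calculate_s3_etag(file, chunk_size=8 * 1024 * 1024):
        """As above, but returning the plain empty-content MD5 for zero-byte files"""
        md5s = []
        while True:
            data = file.read(chunk_size)
            if not data:
                break
            md5s.append(hashlib.md5(data))
        if not md5s:
            # Zero-byte file: S3's etag is the MD5 of b'' with no '-<parts>' suffix
            return '"{}"'.format(hashlib.md5(b'').hexdigest())
        if len(md5s) == 1:
            return '"{}"'.format(md5s[0].hexdigest())
        digests = b''.join(m.digest() for m in md5s)
        return '"{}-{}"'.format(hashlib.md5(digests).hexdigest(), len(md5s))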