""" | |
Compute MD5 hashes of files in a local directory or in an S3 bucket. | |
Outputs a CSV file with columns `path` and `md5`. | |
When computing MD5 hashes of objects in an S3 bucket, `path` corresponds | |
to the S3 URI. | |
""" | |
import multiprocessing as mp
import boto3 as boto
import pandas as pd
import functools
import tempfile
import argparse
import hashlib
import os

MAX_KEYS = 1000  # maximum number of AWS objects to fetch per call
MAX_KEYS = min(MAX_KEYS, 1000)  # This is an AWS-imposed limit (DO NOT CHANGE!)
def read_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local-path", required=False,
            help="Directory containing files to compute md5 hashes of.")
    parser.add_argument("--s3-bucket", required=False,
            help="S3 bucket containing S3 objects to compute md5 hashes of.")
    parser.add_argument("--s3-key", required=False,
            help="The key within the prespecified S3 bucket containing S3 "
                 "objects to compute md5 hashes of.")
    parser.add_argument("--output-path", default="md5_manifest.csv")
    parser.add_argument("--num-cores", type=int, required=False,
            help="How many processes to use during download and/or hashing. "
                 "-1 specifies all cores.")
    parser.add_argument("--profile", default=None,
            help="AWS profile to use.")
    args = parser.parse_args()
    return args
def get_local_files(target_dir):
    """Recursively collect the paths of all files under `target_dir`."""
    paths = []
    for root, dirs, files in os.walk(target_dir):
        for f in files:
            paths.append(os.path.join(root, f))
    return pd.DataFrame({"path": paths})
def get_s3_client(profile_name):
    session = boto.Session(profile_name=profile_name)
    client = session.client("s3")
    return client
def next_s3_list_objects_batch(s3_client, s3_bucket, s3_key, start_after=""):
    """Fetch the next batch of up to MAX_KEYS object listings under the prefix."""
    results = s3_client.list_objects_v2(
            Bucket=s3_bucket,
            MaxKeys=MAX_KEYS,
            Prefix=s3_key,
            StartAfter=start_after)
    return results.get("Contents", [])
def get_s3_object_list(s3_client, s3_bucket, s3_key):
    """List every object under the prefix, paging with `StartAfter`."""
    previous_batch_length = MAX_KEYS
    start_after = ""
    all_objects = []
    # Keep fetching until a batch comes back smaller than MAX_KEYS,
    # which means there is nothing left to list.
    while previous_batch_length == MAX_KEYS:
        batch = next_s3_list_objects_batch(
                s3_client=s3_client,
                s3_bucket=s3_bucket,
                s3_key=s3_key,
                start_after=start_after)
        all_objects += batch
        previous_batch_length = len(batch)
        start_after = batch[-1]["Key"] if len(batch) else ""
    return all_objects
def _hash_s3_object(s3_key, s3_bucket, s3_profile):
    # A fresh client is created per call so this function is safe to run
    # in worker processes spawned by multiprocessing.
    s3_client = get_s3_client(s3_profile)
    hash_val = None
    with tempfile.TemporaryFile() as temp_f:
        try:
            s3_client.download_fileobj(
                    Bucket=s3_bucket,
                    Key=s3_key,
                    Fileobj=temp_f)
            temp_f.seek(0)
            hash_val = md5sum(file_obj=temp_f)
        except Exception as e:
            print("could not download {} because {}".format(s3_key, str(e)))
    return hash_val
def hash_s3_objects(object_list, s3_bucket, s3_profile, num_cores=None):
    """Returns pandas DataFrame with columns `path` and `md5`"""
    s3_keys = [obj["Key"] for obj in object_list]
    hash_s3_object = functools.partial(
            _hash_s3_object,
            s3_bucket=s3_bucket,
            s3_profile=s3_profile)
    if num_cores is not None:
        with mp.Pool(num_cores) as pool:
            md5_hashes = pool.map(hash_s3_object, s3_keys)
    else:
        md5_hashes = list(map(hash_s3_object, s3_keys))
    paths = ["s3://{}".format(os.path.join(s3_bucket, k)) for k in s3_keys]
    result = pd.DataFrame({"path": paths, "md5": md5_hashes})
    return result
def _block_hash(file_obj, blocksize, hash=None):
    """Update an MD5 hash by reading `file_obj` in blocks of `blocksize` bytes."""
    if hash is None:
        hash = hashlib.md5()
    for block in iter(lambda: file_obj.read(blocksize), b""):
        hash.update(block)
    return hash
def md5sum(filename=None, file_obj=None, blocksize=50*1024**2):
    """Return the hex MD5 digest of a file path or an open binary file object."""
    if file_obj is not None:
        hash = _block_hash(
                file_obj=file_obj,
                blocksize=blocksize)
    elif filename is not None:
        with open(filename, "rb") as file_obj:
            hash = _block_hash(
                    file_obj=file_obj,
                    blocksize=blocksize)
    else:
        raise TypeError("Either filename or file_obj must be set.")
    return hash.hexdigest()
def main():
    args = read_args()
    args.num_cores = mp.cpu_count() if args.num_cores == -1 else args.num_cores
    if args.local_path is not None:
        all_files = get_local_files(target_dir=args.local_path)
        all_files['md5'] = all_files.path.apply(md5sum)
        all_files.to_csv(args.output_path, index=False)
    elif args.s3_key is not None and args.s3_bucket is not None:
        s3_client = get_s3_client(profile_name=args.profile)
        s3_object_list = get_s3_object_list(
                s3_client=s3_client,
                s3_bucket=args.s3_bucket,
                s3_key=args.s3_key)
        s3_hashes = hash_s3_objects(
                object_list=s3_object_list,
                s3_bucket=args.s3_bucket,
                s3_profile=args.profile,
                num_cores=args.num_cores)
        s3_hashes.to_csv(args.output_path, index=False)


if __name__ == "__main__":
    main()