Massively parallel copy of an S3 bucket using PySpark: list keys by one-character prefix in parallel, subtract the destination bucket's keys from the source's to find what is missing, copy the missing keys server-side with a thread pool per Spark partition, then compute an order-independent MD5 fingerprint of each bucket so the two can be compared.
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor

from boto.s3.connection import S3Connection
from pyspark import SparkContext
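
# computeTargets runs on the executors; each task opens its own boto
# connection (connections do not survive pickling) and lists every key
# under a single prefix.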
def computeTargets(bucketName, prefix=""):
    s3 = S3Connection()
    return [key.name for key in s3.get_bucket(bucketName).list(prefix=prefix)]
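
# getBucketRDD parallelises the listing itself: each character of
# keyPrefixes becomes one element of the RDD, so up to 38 prefix listings
# run concurrently before the results are filtered and repartitioned.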
def getBucketRDD(bucketName, keyPrefixes="0123456789abcdefghijklmnopqrstuvwxyz-_", partitions=1000):
    # sc is the module-level SparkContext created at the bottom of the file;
    # it is looked up at call time, so the definition order is fine.
    prefixRDD = sc.parallelize(keyPrefixes)  # one element per character
    prefixRDD.setName("Bucket Key Prefixes")
    keysNamesRDD = prefixRDD.flatMap(lambda pfx: computeTargets(bucketName, pfx))
    keysNamesRDD.setName("Keys in bucket %s" % bucketName)
    print("{} keys are in this bucket.".format(keysNamesRDD.count()))
    filteredKeyNamesRDD = keysNamesRDD.filter(lambda name: not name.startswith("logs"))
    filteredKeyNamesRDD.setName("Keys in bucket %s (filtered)" % bucketName)
    print("{} keys are in this bucket and not logs.".format(filteredKeyNamesRDD.count()))
    return filteredKeyNamesRDD.repartition(partitions)
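
# copyTarget performs a server-side S3 copy (no object data flows through
# the executor) and retries a handful of times; keys already present in the
# destination are skipped, which makes re-runs cheap.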
def copyTarget(keyName, attemptCount=5):
    # sourceBucketName / destinationBucketName are module-level globals,
    # assigned at the bottom of the file and resolved at call time.
    s3 = S3Connection()
    sourceKey = s3.get_bucket(sourceBucketName).get_key(keyName)
    destBucket = s3.get_bucket(destinationBucketName)
    if destBucket.get_key(keyName) is None:
        for _ in range(attemptCount):
            try:
                sourceKey.copy(destBucket, keyName)
                return True
            except Exception as e:
                print("Great Failure: {0} {1}".format(keyName, e))
        # Out of attempts. Fail.
        return False
    else:
        print("Key already exists in remote bucket. Skipping.")
        return True
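
# copyKeyList is applied with mapPartitions, so each Spark task drives one
# thread pool: the threads overlap S3 round-trips, multiplying the
# parallelism of the partitions themselves.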
def copyKeyList(keyNameList):
    failures = []
    with ThreadPoolExecutor(max_workers=10) as executor:
        # Start the copy operations and map each future back to its key name.
        futureToKeyName = {executor.submit(copyTarget, keyName): keyName for keyName in keyNameList}
        for future in concurrent.futures.as_completed(futureToKeyName):
            keyName = futureToKeyName[future]
            if not future.result():
                failures.append(keyName)
    return failures
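
# bucketToBucket computes the copy set as an RDD subtraction: key names in
# the source listing but not in the destination listing. Keys are compared
# by name only, so a changed object with an unchanged name is not re-copied.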
def bucketToBucket(srcBucketName, dstBucketName):
    srcListRDD = getBucketRDD(srcBucketName)
    dstListRDD = getBucketRDD(dstBucketName)
    srcListRDD.setName("Source Bucket")
    dstListRDD.setName("Destination Bucket")
    toCopyRDD = srcListRDD.subtract(dstListRDD)
    toCopyRDD.setName("Keys to move")
    groupedCopyListsRDD = toCopyRDD.repartition(1000)
    # mapPartitions flattens the per-partition failure lists into an RDD of
    # failed key names; collect() also behaves sensibly when nothing failed.
    failedKeysRDD = groupedCopyListsRDD.mapPartitions(copyKeyList)
    return failedKeysRDD.collect()
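
# Caveat (an added note, not from the original gist): S3 only guarantees the
# ETag equals the object's MD5 for non-multipart uploads; a multipart ETag
# carries a "-N" suffix, and int(..., 16) below would raise on it.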
def getKeyMD5Size(keyName, bucketName):
    # Return the key's MD5 hash as an integer and its size in bytes.
    s3 = S3Connection()
    k = s3.get_bucket(bucketName).get_key(keyName)
    assert k is not None, "Not a key"
    unquoted = k.etag[1:-1]  # AWS wraps the hex MD5 in literal double quotes
    return int(unquoted, 16), int(k.size)
def reduceMD5Size(md5SizeTuple1, md5SizeTuple2):
    # MRG NOTE: This hash is not at all secure against intentional
    # tampering; the design is only an associative/commutative
    # copy-verification mechanism.
    md5Int1, size1 = md5SizeTuple1
    md5Int2, size2 = md5SizeTuple2
    return (md5Int1 ^ md5Int2), size1 + size2
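
# XOR-folding the per-key MD5s gives an order-independent fingerprint: two
# buckets with identical contents reduce to the same value no matter how
# Spark orders the reduction.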
def bucketMerkleHash(bucketName):
    bucketRDD = getBucketRDD(bucketName)
    md5RDD = bucketRDD.map(lambda keyName: getKeyMD5Size(keyName, bucketName))
    md5RDD.setName("Key MD5s")
    collapsedHash, size = md5RDD.reduce(reduceMD5Size)
    print(collapsedHash, size)
    print("Bucket contains {0:f} GB hash {1:#x}".format(size / 1e9, collapsedHash))
    return collapsedHash, size

sourceBucketName = "PLACEHOLDER"
destinationBucketName = "PLACEHOLDER"
sc = SparkContext(appName="image-copy")
failedKeyNames = bucketToBucket(sourceBucketName, destinationBucketName)
print("{} keys failed to copy.".format(len(failedKeyNames)))
sourceHash, sourceSize = bucketMerkleHash(sourceBucketName)
destinationHash, destinationSize = bucketMerkleHash(destinationBucketName)
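# A possible extension (a sketch, not in the original gist): compare the two
# fingerprints so a mismatch fails the run loudly instead of relying on a
# visual diff of the printed values.
assert (sourceHash, sourceSize) == (destinationHash, destinationSize), \
    "source and destination bucket fingerprints differ"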