Created
May 12, 2020 19:26
-
-
Save Quentin-M/c5afec9a79886f7f7c0e02c351d29301 to your computer and use it in GitHub Desktop.
Decrypts a set of S3 files produced by a DB Activity Stream, using an LRU cache to save time/$ (local dev / Python version; a Golang version also exists to run as a Lambda on the Kinesis stream)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64 | |
import zlib | |
import re | |
import argparse | |
import os | |
import glob | |
import locale | |
from json import JSONDecoder, JSONDecodeError | |
import aws_encryption_sdk | |
from aws_encryption_sdk.internal.crypto import WrappingKey | |
from aws_encryption_sdk.key_providers.raw import RawMasterKeyProvider | |
from aws_encryption_sdk.identifiers import WrappingAlgorithm, EncryptionKeyType | |
import boto3 | |
from lru import LRU | |
# Matches the first non-whitespace character; stream_json() uses it to locate
# the start of the next JSON document inside its buffer.
NOT_WHITESPACE = re.compile(r"[^\s]")

# Command-line interface. The arguments are parsed at module level because
# decrypt_key() and main() read the resulting module-level ``args``.
parser = argparse.ArgumentParser(
    description="Decrypt Database Activity Streams JSON files in place, "
                "caching decrypted KMS data keys in an LRU.")
parser.add_argument("--path", type=str, required=True, help="Path to the activity directory to parse")
parser.add_argument("--region", type=str, default="eu-west-1", help="AWS Region")
# NOTE(review): this value is sent to KMS as the 'aws:rds:dbc-id' encryption
# context, so it looks like a DB cluster resource id rather than a Kinesis
# stream id -- confirm before rewording the help text.
parser.add_argument("--data-stream-id", type=str, required=True, help="Kinesis Data Stream ID (e.g. 'cluster-AHDJDNADNJQWE')")
args = parser.parse_args()
def stream_json(file_obj, buf_size=1024, decoder=JSONDecoder()):
    """Lazily yield each JSON document found in a text stream.

    The stream may contain several JSON values back to back, optionally
    separated by whitespace. It is read ``buf_size`` characters at a time;
    a value that is still incomplete at the end of a chunk is kept and
    retried once more data has been appended.

    Args:
        file_obj: Object whose ``read(n)`` returns text, and an empty value
            at end of input.
        buf_size: Number of characters requested per ``read`` call.
        decoder: ``JSONDecoder`` used to parse each value.

    Yields:
        Every decoded JSON value, in stream order.

    Raises:
        JSONDecodeError: When the input ends while the most recent decode
            attempt had failed (truncated or malformed trailing data).
    """
    pending = ""
    failure = None
    chunk = file_obj.read(buf_size)
    while chunk:
        pending += chunk
        cursor = 0
        found = NOT_WHITESPACE.search(pending, cursor)
        while found:
            cursor = found.start()
            try:
                value, cursor = decoder.raw_decode(pending, cursor)
            except JSONDecodeError as err:
                # Possibly just an incomplete value: remember the error and
                # wait for the next chunk before giving up.
                failure = err
                break
            failure = None
            yield value
            found = NOT_WHITESPACE.search(pending, cursor)
        # Drop what has been consumed; keep any unfinished tail.
        pending = pending[cursor:]
        chunk = file_obj.read(buf_size)
    if failure is not None:
        raise failure
class MyRawMasterKeyProvider(RawMasterKeyProvider):
    """Raw master key provider that serves one fixed symmetric wrapping key.

    The plaintext key given to the constructor is wrapped in a ``WrappingKey``
    using ``WrappingAlgorithm.AES_256_GCM_IV12_TAG16_NO_PADDING`` and that same
    object is returned for every key id requested via ``_get_raw_key``.
    """
    # Provider identifier for this key provider.
    # NOTE(review): presumably has to match the provider id recorded in the
    # encrypted messages being decrypted -- confirm.
    provider_id = "BC"

    def __new__(cls, *args, **kwargs):
        # super(RawMasterKeyProvider, cls) starts the attribute lookup *after*
        # RawMasterKeyProvider in the MRO, so any __new__ defined on
        # RawMasterKeyProvider itself is skipped and the constructor arguments
        # are not forwarded.
        # NOTE(review): presumably a workaround for how the SDK builds
        # providers from a config object -- confirm against the SDK source.
        obj = super(RawMasterKeyProvider, cls).__new__(cls)
        return obj

    def __init__(self, plain_key):
        """Store ``plain_key`` as a symmetric ``WrappingKey`` on the instance."""
        RawMasterKeyProvider.__init__(self)
        self.wrapping_key = WrappingKey(wrapping_algorithm=WrappingAlgorithm.AES_256_GCM_IV12_TAG16_NO_PADDING,
                                        wrapping_key=plain_key, wrapping_key_type=EncryptionKeyType.SYMMETRIC)

    def _get_raw_key(self, key_id):
        """Return the single stored wrapping key; ``key_id`` is ignored."""
        return self.wrapping_key
def decrypt_key(encrypted_data_key, kms, keys_cache, data_stream_id=None):
    """Return the plaintext data key for ``encrypted_data_key``, using a cache.

    On a cache hit the stored plaintext is returned without calling KMS. On a
    miss the key is base64-decoded, decrypted through ``kms.decrypt`` and the
    result is stored in ``keys_cache`` for later calls.

    Call statistics are kept on the function itself as ``decrypt_key.hit``
    and ``decrypt_key.miss``.

    Args:
        encrypted_data_key: Base64-encoded encrypted data key; also used as
            the cache key.
        kms: Client object exposing ``decrypt(CiphertextBlob=...,
            EncryptionContext=...)`` and returning a mapping with a
            ``'Plaintext'`` entry.
        keys_cache: Mutable mapping used as the cache (raises ``KeyError`` on
            a missing key).
        data_stream_id: Value sent as the ``'aws:rds:dbc-id'`` encryption
            context. Defaults to the ``--data-stream-id`` command-line value.

    Returns:
        The ``'Plaintext'`` value returned by KMS for this key.
    """
    # Single lookup: try the cache first, fall through to KMS on a miss.
    try:
        plaintext_key = keys_cache[encrypted_data_key]
    except KeyError:
        pass
    else:
        decrypt_key.hit += 1
        return plaintext_key

    decrypt_key.miss += 1
    if data_stream_id is None:
        # Backward-compatible default: the module-level parsed CLI arguments.
        data_stream_id = args.data_stream_id
    response = kms.decrypt(
        CiphertextBlob=base64.b64decode(encrypted_data_key),
        EncryptionContext={'aws:rds:dbc-id': data_stream_id},
    )
    plaintext_key = response['Plaintext']
    keys_cache[encrypted_data_key] = plaintext_key
    return plaintext_key


# Initialise the counters up front so they can be read (e.g. in a summary)
# even when decrypt_key() was never called.
decrypt_key.hit = 0
decrypt_key.miss = 0
def decrypt_payload(encrypted_payload, encrypted_data_key, kms, keys_cache):
    """Decrypt and decompress one activity-stream payload into text.

    Steps: resolve the plaintext data key through ``decrypt_key`` (cached),
    use it as the wrapping key of a ``MyRawMasterKeyProvider``, decrypt the
    base64-decoded payload with the AWS Encryption SDK, then gunzip the
    result and decode it as UTF-8.

    Args:
        encrypted_payload: Base64-encoded encrypted payload.
        encrypted_data_key: Base64-encoded encrypted data key for the payload.
        kms: KMS client forwarded to ``decrypt_key``.
        keys_cache: Data-key cache forwarded to ``decrypt_key``.

    Returns:
        The decrypted, decompressed payload as a ``str``.
    """
    data_key = decrypt_key(encrypted_data_key, kms, keys_cache)
    key_provider = MyRawMasterKeyProvider(data_key)
    key_provider.add_master_key("DataKey")
    mat_provider = aws_encryption_sdk.DefaultCryptoMaterialsManager(master_key_provider=key_provider)
    # The message header returned alongside the plaintext is not needed.
    decrypted_plaintext, _header = aws_encryption_sdk.decrypt(
        source=base64.b64decode(encrypted_payload), materials_manager=mat_provider)
    # 16 + MAX_WBITS is the documented wbits value for gzip-wrapped data with
    # the largest window; the previous MAX_WBITS + 1 (= 16) is outside the
    # documented ranges.
    return zlib.decompress(decrypted_plaintext, 16 + zlib.MAX_WBITS).decode('utf-8')
def main():
    """Decrypt every ``*.json`` file under ``args.path`` in place and print statistics.

    Each file is read as a stream of JSON records; for every record the
    ``databaseActivityEvents`` payload is decrypted with the record's ``key``
    and written as one line to ``<file>.decrypted``, which then replaces the
    original file. Finally the number of files, records, KMS cache hits/misses
    and the cache fill level are printed.
    """
    # 'en_US' is not a valid locale name on every host; try a common
    # alternative and otherwise keep the default locale instead of crashing.
    for candidate in ('en_US', 'en_US.UTF-8'):
        try:
            locale.setlocale(locale.LC_ALL, candidate)
            break
        except locale.Error:
            continue

    session = boto3.session.Session()
    kms = session.client('kms', region_name=args.region)

    cache_capacity = 1024
    keys_cache = LRU(cache_capacity)

    files_counter = 0
    lines_counter = 0
    for fp in glob.glob(os.path.join(args.path, "**/*.json"), recursive=True):
        print("Decrypting %s ..." % fp)
        decrypted_fp = fp + ".decrypted"
        with open(fp, "r") as src:
            files_counter += 1
            with open(decrypted_fp, "w") as dst:
                for record in stream_json(src):
                    lines_counter += 1
                    print(decrypt_payload(record['databaseActivityEvents'], record['key'], kms, keys_cache), file=dst)
        # Swap the decrypted copy over the original in a single step.
        os.replace(decrypted_fp, fp)

    # getattr: the counters may not exist if decrypt_key() was never called.
    cache_hits = getattr(decrypt_key, "hit", 0)
    cache_misses = getattr(decrypt_key, "miss", 0)

    print("Files decrypted: %s" % locale.format_string("%d", files_counter, grouping=True))
    print("Lines decrypted: %s" % locale.format_string("%d", lines_counter, grouping=True))
    print("KMS Decrypt (cached): %s" % locale.format_string("%d", cache_hits, grouping=True))
    print("KMS Decrypt (called): %s" % locale.format_string("%d", cache_misses, grouping=True))
    print("KMS Cache Size: %d / %d" % (len(keys_cache), cache_capacity))


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment