Skip to content

Instantly share code, notes, and snippets.

@Quentin-M
Created May 12, 2020 19:26
Show Gist options
  • Save Quentin-M/c5afec9a79886f7f7c0e02c351d29301 to your computer and use it in GitHub Desktop.
Save Quentin-M/c5afec9a79886f7f7c0e02c351d29301 to your computer and use it in GitHub Desktop.
Decrypts a set of S3 files produced by an RDS Database Activity Stream, using an LRU cache of decrypted data keys to save time and KMS cost. (Python version for local development; a Go version also exists to run as a Lambda function on the Kinesis stream.)
import base64
import zlib
import re
import argparse
import os
import glob
import locale
from json import JSONDecoder, JSONDecodeError
import aws_encryption_sdk
from aws_encryption_sdk.internal.crypto import WrappingKey
from aws_encryption_sdk.key_providers.raw import RawMasterKeyProvider
from aws_encryption_sdk.identifiers import WrappingAlgorithm, EncryptionKeyType
import boto3
from lru import LRU
# Matches the first non-whitespace character; stream_json uses it to skip the
# whitespace between concatenated JSON values.
NOT_WHITESPACE = re.compile(r"[^\s]")

# Command-line interface. Arguments are parsed at import time because
# decrypt_key and main read the module-level ``args`` directly.
parser = argparse.ArgumentParser(
    description="Decrypt, in place, the S3 files produced by an RDS Database Activity Stream.")
parser.add_argument("--path", type=str, required=True, help="Path to the activity directory to parse")
parser.add_argument("--region", type=str, default="eu-west-1", help="AWS Region")
parser.add_argument("--data-stream-id", type=str, required=True,
                    help="DB cluster resource ID (e.g. 'cluster-AHDJDNADNJQWE'); used as the "
                         "'aws:rds:dbc-id' KMS encryption context")
args = parser.parse_args()
def stream_json(file_obj, buf_size=1024, decoder=None):
    """Yield every top-level JSON value from a stream of concatenated JSON.

    The input is expected to hold JSON values written back to back, separated
    by optional whitespace only, so ``json.load`` cannot read it in one go.
    The stream is read in chunks and decoded incrementally with
    ``JSONDecoder.raw_decode``.

    Args:
        file_obj: text-mode file-like object exposing ``read(size)``.
        buf_size: number of characters requested per read while no value is
            pending. While a value is incomplete, the request size doubles so
            a large value is re-parsed only O(log n) times.
        decoder: ``JSONDecoder`` to use; a fresh default decoder when omitted
            (avoids sharing one instance through a default argument).

    Yields:
        Each decoded top-level value, in stream order.

    Raises:
        JSONDecodeError: if undecodable data remains at end of stream
            (truncated or malformed input).

    NOTE: a bare top-level *number* that ends exactly at a read boundary may be
    decoded before its remaining digits arrive; object/array values are not
    affected.
    """
    if decoder is None:
        decoder = JSONDecoder()
    buf = ""
    read_size = buf_size
    ex = None
    while True:
        block = file_obj.read(read_size)
        if not block:
            break
        buf += block
        pos = 0
        while True:
            match = NOT_WHITESPACE.search(buf, pos)
            if not match:
                break
            pos = match.start()
            try:
                obj, pos = decoder.raw_decode(buf, pos)
            except JSONDecodeError as e:
                # Most likely a value cut off at the end of the buffer: keep it
                # and retry once more data has been appended.
                ex = e
                break
            ex = None
            yield obj
        # Drop everything already decoded; keep the pending tail.
        buf = buf[pos:]
        # While a value is pending, grow the next read geometrically. Otherwise a
        # record of size S would be re-parsed S / buf_size times (quadratic).
        read_size = read_size * 2 if ex is not None else buf_size
    if ex is not None:
        raise ex
class MyRawMasterKeyProvider(RawMasterKeyProvider):
    """Raw master-key provider built around a single plaintext data key.

    In this file it is constructed in decrypt_payload with the KMS-decrypted
    data key returned by decrypt_key, and handed to aws_encryption_sdk.decrypt
    through a DefaultCryptoMaterialsManager. The key is exposed as a symmetric
    AES-256-GCM wrapping key (12-byte IV, 16-byte tag, no padding), and the
    same wrapping key is returned for every key ID.
    """

    # Provider ID reported by this provider. NOTE(review): presumably must match
    # the provider ID recorded in the encrypted activity-stream messages —
    # confirm before changing.
    provider_id = "BC"

    def __new__(cls, *args, **kwargs):
        # Deliberately starts construction at the class *after*
        # RawMasterKeyProvider in the MRO, so RawMasterKeyProvider's own __new__
        # (and its handling of constructor arguments) is bypassed.
        # NOTE(review): presumably needed because __init__ takes a raw key rather
        # than the SDK's usual config object — verify against the installed
        # aws-encryption-sdk version.
        obj = super(RawMasterKeyProvider, cls).__new__(cls)
        return obj

    def __init__(self, plain_key):
        # plain_key: the plaintext data key (in this file, decrypt_key's result).
        RawMasterKeyProvider.__init__(self)
        self.wrapping_key = WrappingKey(wrapping_algorithm=WrappingAlgorithm.AES_256_GCM_IV12_TAG16_NO_PADDING,
                                        wrapping_key=plain_key, wrapping_key_type=EncryptionKeyType.SYMMETRIC)

    def _get_raw_key(self, key_id):
        # key_id is ignored: every key ID resolves to the single wrapping key.
        return self.wrapping_key
def decrypt_key(encrypted_data_key, kms, keys_cache, data_stream_id=None):
    """Return the plaintext data key for *encrypted_data_key*, caching KMS calls.

    Cache hits and misses are counted on the function object itself
    (``decrypt_key.hit`` / ``decrypt_key.miss``) so callers can report them.

    Args:
        encrypted_data_key: base64 text of the KMS-encrypted data key; also
            used as the cache key.
        kms: boto3 KMS client (anything with a compatible ``decrypt`` method).
        keys_cache: mutable mapping from encrypted key to plaintext key
            (an ``LRU`` in main; any mapping supporting ``in`` works).
        data_stream_id: value of the ``aws:rds:dbc-id`` encryption context.
            Defaults to the ``--data-stream-id`` command-line argument.

    Returns:
        The ``Plaintext`` field of the KMS decrypt response.
    """
    if not hasattr(decrypt_key, "hit"):
        decrypt_key.hit = 0
        decrypt_key.miss = 0
    if encrypted_data_key in keys_cache:
        decrypt_key.hit += 1
        return keys_cache[encrypted_data_key]
    decrypt_key.miss += 1
    if data_stream_id is None:
        # Fall back to the CLI value, as the original version always did.
        data_stream_id = args.data_stream_id
    response = kms.decrypt(CiphertextBlob=base64.b64decode(encrypted_data_key),
                           EncryptionContext={'aws:rds:dbc-id': data_stream_id})
    plaintext_key = response['Plaintext']
    keys_cache[encrypted_data_key] = plaintext_key
    return plaintext_key
def decrypt_payload(encrypted_payload, encrypted_data_key, kms, keys_cache):
    """Decrypt and decompress one activity-stream payload.

    Args:
        encrypted_payload: base64 text of the encrypted payload (the record's
            ``databaseActivityEvents`` field in main).
        encrypted_data_key: base64 text of the KMS-encrypted data key (the
            record's ``key`` field in main); resolved through decrypt_key.
        kms: boto3 KMS client passed through to decrypt_key.
        keys_cache: data-key cache passed through to decrypt_key.

    Returns:
        The gzip-decompressed plaintext decoded as UTF-8 text.

    NOTE(review): module-level ``aws_encryption_sdk.decrypt`` is deprecated in
    newer aws-encryption-sdk releases in favour of ``EncryptionSDKClient`` —
    verify against the installed version before upgrading.
    """
    key_provider = MyRawMasterKeyProvider(decrypt_key(encrypted_data_key, kms, keys_cache))
    key_provider.add_master_key("DataKey")
    mat_provider = aws_encryption_sdk.DefaultCryptoMaterialsManager(master_key_provider=key_provider)
    decrypted_plaintext, _header = aws_encryption_sdk.decrypt(source=base64.b64decode(encrypted_payload),
                                                              materials_manager=mat_provider)
    # 16 + MAX_WBITS is the documented wbits value for a gzip header/trailer.
    # The previous MAX_WBITS + 1 (== 16) is not a documented value.
    return zlib.decompress(decrypted_plaintext, 16 + zlib.MAX_WBITS).decode('utf-8')
def main():
    """Decrypt every ``*.json`` file under ``--path`` (recursively) in place.

    Each file is parsed as a sequence of concatenated JSON records. Every
    record's ``databaseActivityEvents`` payload is decrypted with the data key
    in its ``key`` field and written as one line of text to a temporary
    ``<file>.decrypted`` sibling. That sibling then replaces the original file.
    Summary counters are printed at the end.

    NOTE(review): replaced files keep their ``.json`` name, so a later run over
    the same directory will pick them up again although they no longer hold
    encrypted records — keep a copy of the input before re-running.
    """
    cache_size = 1024
    session = boto3.session.Session()
    kms = session.client('kms', region_name=args.region)
    keys_cache = LRU(cache_size)
    files_counter = 0
    lines_counter = 0
    for fp in glob.glob(os.path.join(args.path, "**/*.json"), recursive=True):
        print("Decrypting %s ..." % fp)
        tmp_fp = fp + ".decrypted"
        # Explicit UTF-8 so reading/writing does not depend on the process locale.
        with open(fp, "r", encoding="utf-8") as i, open(tmp_fp, "w", encoding="utf-8") as o:
            for j in stream_json(i):
                lines_counter += 1
                print(decrypt_payload(j['databaseActivityEvents'], j['key'], kms, keys_cache), file=o)
        # Both handles are closed before the swap (required on Windows).
        # os.replace overwrites atomically, so a crash cannot lose both copies.
        os.replace(tmp_fp, fp)
        files_counter += 1
    # format(n, ",") yields the same grouping as the former en_US locale
    # without calling setlocale, which fails on hosts lacking that locale.
    # getattr defaults cover runs where decrypt_key was never called.
    print("Files decrypted: %s" % format(files_counter, ","))
    print("Lines decrypted: %s" % format(lines_counter, ","))
    print("KMS Decrypt (cached): %s" % format(getattr(decrypt_key, "hit", 0), ","))
    print("KMS Decrypt (called): %s" % format(getattr(decrypt_key, "miss", 0), ","))
    print("KMS Cache Size: %d / %d" % (len(keys_cache), cache_size))
# Script entry point: run only when executed directly, never on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment