Created July 27, 2019 03:18
Save danifus/73d258df243bbb386c1dd64c0888cddf to your computer and use it in GitHub Desktop.
Implementing WinZip AES encryption / decryption with the zipfile refactor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import zipfile | |
import zipfile_aes | |
# Name of the archive created and then read back by this example.
ARCHIVE_NAME = 'new_test.zip'
# Member stored inside the archive.
MEMBER_NAME = 'test.txt'
secret_password = b'lost art of keeping a secret'

# Create an LZMA-compressed archive whose member is WinZip AES encrypted.
with zipfile_aes.AESZipFile(ARCHIVE_NAME,
                            'w',
                            compression=zipfile.ZIP_LZMA,
                            encryption=zipfile_aes.WZ_AES) as zf:
    zf.setpassword(secret_password)
    zf.writestr(MEMBER_NAME, "What ever you do, don't tell anyone!")

# Re-open the archive and decrypt the member with the same password.
with zipfile_aes.AESZipFile(ARCHIVE_NAME) as zf:
    zf.setpassword(secret_password)
    my_secrets = zf.read(MEMBER_NAME)

print(my_secrets)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import struct | |
# requires pip install pycryptodomex | |
from Cryptodome.Protocol.KDF import PBKDF2 | |
from Cryptodome.Cipher import AES | |
from Cryptodome.Hash import HMAC | |
from Cryptodome.Hash.SHA1 import SHA1Hash | |
from Cryptodome.Util import Counter | |
from Cryptodome import Random | |
from zipfile import ( | |
ZIP_BZIP2, BadZipFile, BaseDecrypter, ZipFile, ZipInfo, ZipExtFile, | |
_ZipWriteFile, crc32, _MASK_ENCRYPTED, CRCZipDecrypter, | |
) | |
# Value passed as ``AESZipFile(..., encryption=WZ_AES)`` to request WinZip AES
# encryption when writing members.
WZ_AES = 'WZ_AES'
# Compression method written into the local/central headers of AES-encrypted
# members; the method actually used is kept in the AES extra field instead.
WZ_AES_COMPRESS_TYPE = 99
# AES extra-field vendor versions. Version 2 stores a CRC-32 of 0.
WZ_AES_V1 = 0x0001
WZ_AES_V2 = 0x0002
# Vendor ID stored in the AES extra field.
WZ_AES_VENDOR_ID = b'AE'
# Header ID of the WinZip AES extra field.
EXTRA_WZ_AES = 0x9901

# AES strength code (as stored in the extra field) -> AES key size in bits.
_WZ_AES_STRENGTH_BITS = {1: 128, 2: 192, 3: 256}
# Salt length in bytes for each strength code: half of the key length.
WZ_SALT_LENGTHS = {
    code: bits // 16 for code, bits in _WZ_AES_STRENGTH_BITS.items()
}
# AES key length in bytes for each strength code.
WZ_KEY_LENGTHS = {
    code: bits // 8 for code, bits in _WZ_AES_STRENGTH_BITS.items()
}
class AESZipDecrypter(BaseDecrypter):
    """Decrypt and authenticate a WinZip AES encrypted archive member.

    The member data begins with a salt and a password verification value,
    the payload is AES-CTR encrypted, and an ``hmac_size``-byte truncated
    HMAC-SHA1 of the encrypted bytes is checked via :meth:`check_hmac`.
    """

    # Length in bytes of the truncated authentication code.
    hmac_size = 10
    # Length in bytes of the password verification value after the salt.
    pwd_verify_length = 2

    def __init__(self, zinfo, pwd):
        """Bind the decrypter to ``zinfo``; raise RuntimeError if ``pwd`` is empty."""
        self.zinfo = zinfo
        self.name = zinfo.filename
        if not pwd:
            raise RuntimeError("File %r is encrypted, a password is "
                               "required for extraction" % self.name)
        self.pwd = pwd

    def start_decrypt(self, fileobj):
        """Read the encryption header from ``fileobj`` and initialise the ciphers.

        Returns the number of bytes of encryption overhead for the member:
        the salt and password verifier read here plus the trailing
        ``hmac_size``-byte authentication code.

        Raises:
            BadZipFile: the AES strength code is unknown or the header is
                truncated.
            RuntimeError: the password verification value does not match.
        """
        strength = self.zinfo.wz_aes_strength
        try:
            key_length = WZ_KEY_LENGTHS[strength]
            salt_length = WZ_SALT_LENGTHS[strength]
        except KeyError:
            # A corrupt/unsupported extra field should be reported as a bad
            # archive rather than leaking a raw KeyError.
            raise BadZipFile(
                "Unsupported %s strength %r for file %r"
                % (WZ_AES, strength, self.name)
            ) from None

        header_length = salt_length + self.pwd_verify_length
        header = fileobj.read(header_length)
        if len(header) != header_length:
            # Without this check a short read would surface as a confusing
            # "Bad password" error (or a struct error).
            raise BadZipFile(
                "Truncated %s encryption header for file %r"
                % (WZ_AES, self.name)
            )
        salt = header[:salt_length]
        pwd_verify = header[salt_length:]

        # One PBKDF2 call yields: encryption key | MAC key | password verifier.
        keymaterial = PBKDF2(
            self.pwd,
            salt,
            count=1000,
            dkLen=2 * key_length + self.pwd_verify_length,
        )
        if keymaterial[2 * key_length:] != pwd_verify:
            raise RuntimeError("Bad password for file %r" % self.name)

        enckey = keymaterial[:key_length]
        encmac_key = keymaterial[key_length:2 * key_length]
        self._cipher = AES.new(
            enckey,
            AES.MODE_CTR,
            counter=Counter.new(nbits=128, little_endian=True),
        )
        self.hmac = HMAC.new(encmac_key, digestmod=SHA1Hash())
        return header_length + self.hmac_size

    def decrypt(self, data):
        """Feed ``data`` to the MAC, then return its decryption."""
        # The MAC is computed over the encrypted bytes, so update it first.
        self.hmac.update(data)
        return self._cipher.decrypt(data)

    def check_hmac(self, hmac_check):
        """Raise BadZipFile unless ``hmac_check`` matches the computed MAC."""
        from hmac import compare_digest

        expected = self.hmac.digest()[:self.hmac_size]
        # Constant-time comparison avoids leaking how many bytes matched.
        if not compare_digest(expected, hmac_check):
            raise BadZipFile("Bad HMAC check for file %r" % self.name)
class BaseZipEncrypter:
    """Interface for objects that encrypt one archive member while it is written.

    Subclasses supply the bytes written before the member data
    (``encryption_header``), the per-chunk transform (``encrypt``), any
    trailing bytes (``flush``) and the metadata they record on the ZipInfo
    (``update_zipinfo``).
    """

    @staticmethod
    def _missing(method_name):
        """Build the error raised when a subclass leaves ``method_name`` unimplemented."""
        return NotImplementedError(
            'BaseZipEncrypter implementations must implement `%s`.'
            % method_name
        )

    def update_zipinfo(self, zipinfo):
        """Record encryption details on ``zipinfo``."""
        raise self._missing('update_zipinfo')

    def encrypt(self, data):
        """Return the encrypted form of ``data``."""
        raise self._missing('encrypt')

    def encryption_header(self):
        """Return the bytes written before the encrypted member data."""
        raise self._missing('encryption_header')

    def flush(self):
        """Return trailing bytes to append after the last chunk (none by default)."""
        return b''
class AESZipEncrypter(BaseZipEncrypter):
    """Encrypt one archive member with WinZip AES.

    A fresh random salt is generated per instance; the key material is
    derived from ``pwd`` with PBKDF2 and data is encrypted with AES-CTR and
    authenticated with a truncated HMAC-SHA1 appended by :meth:`flush`.
    """

    # Length in bytes of the truncated authentication code.
    hmac_size = 10
    # Requested key size in bits -> strength code stored in the extra field.
    _nbits_to_strength = {128: 1, 192: 2, 256: 3}

    def __init__(self, pwd, nbits=256, force_wz_aes_version=None):
        """Derive keys for ``pwd``.

        Args:
            pwd: password used to derive the keys; must be non-empty.
            nbits: AES key size, one of 128, 192 or 256.
            force_wz_aes_version: if not None, written as the AES vendor
                version instead of letting the ZipInfo choose one.

        Raises:
            RuntimeError: empty password or unsupported ``nbits``.
        """
        if not pwd:
            raise RuntimeError(
                '%s encryption requires a password.' % WZ_AES
            )
        if nbits not in (128, 192, 256):
            raise RuntimeError(
                "`nbits` must be one of 128, 192, 256. Got '%s'" % nbits
            )
        self.force_wz_aes_version = force_wz_aes_version
        # Reuse the module-level tables so reader and writer agree on sizes.
        self.aes_strength = self._nbits_to_strength[nbits]
        self.salt_length = WZ_SALT_LENGTHS[self.aes_strength]
        key_length = WZ_KEY_LENGTHS[self.aes_strength]

        self.salt = Random.new().read(self.salt_length)
        pwd_verify_length = 2
        # One PBKDF2 call yields: encryption key | MAC key | password verifier.
        keymaterial = PBKDF2(
            pwd,
            self.salt,
            count=1000,
            dkLen=2 * key_length + pwd_verify_length,
        )
        self.encpwdverify = keymaterial[2 * key_length:]
        enckey = keymaterial[:key_length]
        encmac_key = keymaterial[key_length:2 * key_length]
        self.encrypter = AES.new(
            enckey,
            AES.MODE_CTR,
            counter=Counter.new(nbits=128, little_endian=True),
        )
        self.hmac = HMAC.new(encmac_key, digestmod=SHA1Hash())

    def update_zipinfo(self, zipinfo):
        """Record the AES vendor id, strength and (optionally) version on ``zipinfo``."""
        zipinfo.wz_aes_vendor_id = WZ_AES_VENDOR_ID
        zipinfo.wz_aes_strength = self.aes_strength
        if self.force_wz_aes_version is not None:
            zipinfo.wz_aes_version = self.force_wz_aes_version

    def encryption_header(self):
        """Return the salt followed by the password verification value."""
        return self.salt + self.encpwdverify

    def encrypt(self, data):
        """Encrypt ``data`` and feed the ciphertext to the MAC."""
        data = self.encrypter.encrypt(data)
        # The MAC covers the encrypted bytes.
        self.hmac.update(data)
        return data

    def flush(self):
        """Return the truncated authentication code to append after the data."""
        return self.hmac.digest()[:self.hmac_size]
class AESZipInfo(ZipInfo):
    """ZipInfo that also carries the WinZip AES extra-field attributes.

    ``wz_aes_version``, ``wz_aes_vendor_id`` and ``wz_aes_strength`` stay
    ``None`` for members without a WinZip AES extra field.
    """

    # __slots__ on subclasses only need to contain the additional slots.
    __slots__ = (
        'wz_aes_version',
        'wz_aes_vendor_id',
        'wz_aes_strength',
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.wz_aes_version = None
        self.wz_aes_vendor_id = None
        self.wz_aes_strength = None

    def decode_extra_wz_aes(self, ln, extra_payload):
        """Parse the body of a WinZip AES extra field.

        Raises BadZipFile if the declared body length ``ln`` is not 7.
        """
        if ln != 7:
            raise BadZipFile(
                "Corrupt extra field %04x (size=%d)" % (EXTRA_WZ_AES, ln))
        version, vendor_id, strength, actual_compress_type = struct.unpack(
            "<H2sBH", extra_payload)
        self.wz_aes_version = version
        self.wz_aes_vendor_id = vendor_id
        # 0x01 128-bit, 0x02 192-bit, 0x03 256-bit encryption key.
        self.wz_aes_strength = strength
        # The headers carry compression method 99 to flag AES encryption; the
        # extra field holds the method that would otherwise have been stored
        # (e.g. 6 for implode), so expose that as the member's compress_type.
        self.compress_type = actual_compress_type

    def get_extra_decoders(self):
        """Extend the parent's extra-field decoders with the AES decoder."""
        extra_decoders = super().get_extra_decoders()
        extra_decoders[EXTRA_WZ_AES] = self.decode_extra_wz_aes
        return extra_decoders

    def encode_extra(self, crc, compress_type):
        """Build the AES extra field and the header values that go with it.

        Returns ``(extra_bytes, crc, compress_type)``. For members without an
        AES vendor id this is ``(b'', crc, compress_type)`` unchanged.
        Otherwise the header compression type becomes
        ``WZ_AES_COMPRESS_TYPE`` and the CRC is zeroed for vendor version 2.
        """
        if self.wz_aes_vendor_id is None:
            return b'', crc, compress_type

        aes_version = self.wz_aes_version
        if aes_version is None:
            # Versions 1 and 2 differ only in CRC handling: for version 2 the
            # CRC is not used and must be 0. For small files the CRC can leak
            # the contents of the encrypted data; bzip2 already carries its
            # own integrity checks, so the CRC is not required there.
            if self.file_size < 20 or self.compress_type == ZIP_BZIP2:
                aes_version = WZ_AES_V2
            else:
                aes_version = WZ_AES_V1
        if aes_version == WZ_AES_V2:
            crc = 0
        wz_aes_extra = struct.pack(
            "<3H2sBH",
            EXTRA_WZ_AES,
            7,  # extra block body length: H2sBH
            aes_version,
            self.wz_aes_vendor_id,
            self.wz_aes_strength,
            self.compress_type,
        )
        return wz_aes_extra, crc, WZ_AES_COMPRESS_TYPE

    @staticmethod
    def _strip_wz_aes_extra(extra):
        """Return ``extra`` without any WinZip AES (0x9901) records.

        Other records and any trailing bytes that do not form a full record
        header are kept unchanged.
        """
        kept = []
        pos = 0
        while pos + 4 <= len(extra):
            header_id, body_len = struct.unpack('<HH', extra[pos:pos + 4])
            end = pos + 4 + body_len
            if header_id != EXTRA_WZ_AES:
                kept.append(extra[pos:end])
            pos = end
        kept.append(extra[pos:])
        return b''.join(kept)

    def _apply_wz_aes(self, params):
        """Rewrite header ``params`` (extra/crc/compress_type) for AES; return them."""
        wz_aes_extra, crc, compress_type = self.encode_extra(
            params["crc"], params["compress_type"])
        if wz_aes_extra:
            # A member read from an existing archive already holds its AES
            # record in ``extra``; drop it so the record is not duplicated.
            params["extra"] = (
                self._strip_wz_aes_extra(params["extra"]) + wz_aes_extra
            )
        params["crc"] = crc
        params["compress_type"] = compress_type
        return params

    def get_local_header_params(self, zip64=False):
        """Local-header parameters, adjusted for WinZip AES."""
        return self._apply_wz_aes(
            super().get_local_header_params(zip64=zip64))

    def get_central_directory_kwargs(self):
        """Central-directory parameters, adjusted for WinZip AES."""
        return self._apply_wz_aes(super().get_central_directory_kwargs())
class AESZipExtFile(ZipExtFile):
    """ZipExtFile that additionally verifies WinZip AES authentication codes."""

    def check_wz_aes(self):
        """Read the trailing authentication code and let the decrypter verify it."""
        stored_mac = self._fileobj.read(self._decrypter.hmac_size)
        self._decrypter.check_hmac(stored_mac)

    def check_integrity(self):
        """Verify the member once its data has been consumed.

        Non-AES members use the parent's checks. AES members always have
        their HMAC verified. The CRC is also checked when a non-zero CRC is
        supplied, which is not part of the spec for vendor version 2 but is
        still honoured. It is also checked whenever the vendor version is
        not 2, because only version 2 stores a CRC of 0.
        """
        aes_version = self._zinfo.wz_aes_version
        if aes_version is None:
            super().check_integrity()
            return

        self.check_wz_aes()
        crc_supplied = (
            self._expected_crc is not None and self._expected_crc != 0
        )
        if crc_supplied or aes_version != WZ_AES_V2:
            self.check_crc()
class AESZipWriteFile(_ZipWriteFile):
    """Writable member stream that can encrypt data after compression.

    When ``encrypter`` is given, its header is written immediately, every
    compressed chunk is encrypted, and its trailer is appended by
    :meth:`flush_data`. All of these bytes count towards the compressed size.
    """

    def __init__(self, zf, zinfo, zip64, encrypter):
        super().__init__(zf, zinfo, zip64)
        self.encrypter = encrypter
        if self.encrypter:
            self.write_encryption_header()

    def write_encryption_header(self):
        """Write the encrypter's header bytes and count them as compressed data."""
        buf = self.encrypter.encryption_header()
        self._compress_size += len(buf)
        self._fileobj.write(buf)

    def write(self, data):
        """Compress, encrypt and write ``data``; return its length in bytes.

        ``data`` may be any C-contiguous object supporting the buffer
        protocol, as with the standard library's zipfile writer.

        Raises:
            ValueError: the stream is already closed.
        """
        if self.closed:
            raise ValueError('I/O operation on closed file.')
        if isinstance(data, (bytes, bytearray)):
            nbytes = len(data)
        else:
            # len() of a memoryview counts items, not bytes (e.g. an
            # array('H') view); view it as flat bytes so every size below
            # is measured in bytes.
            data = memoryview(data).cast('B')
            nbytes = len(data)
        self._file_size += nbytes
        self._crc = crc32(data, self._crc)
        if self._compressor:
            data = self._compressor.compress(data)
        if self.encrypter:
            data = self.encrypter.encrypt(data)
        self._compress_size += len(data)
        self._fileobj.write(data)
        return nbytes

    def flush_data(self):
        """Write remaining compressor output plus the encrypter's trailer."""
        if self._compressor:
            buf = self._compressor.flush()
        else:
            buf = b""
        if self.encrypter:
            buf = self.encrypter.encrypt(buf)
            buf += self.encrypter.flush()
        self._compress_size += len(buf)
        self._fileobj.write(buf)
class AESZipFile(ZipFile):
    """ZipFile that can read and write WinZip AES encrypted members.

    Extra keyword arguments:
        encryption: ``WZ_AES`` to encrypt members written to the archive.
        encryption_kwargs: optional dict forwarded to ``AESZipEncrypter``
            (e.g. ``nbits``, ``force_wz_aes_version``).
    """

    zipinfo_cls = AESZipInfo
    zipextfile_cls = AESZipExtFile
    zipwritefile_cls = AESZipWriteFile

    def __init__(self, *args, **kwargs):
        encryption = kwargs.pop('encryption', None)
        encryption_kwargs = kwargs.pop('encryption_kwargs', None)
        super().__init__(*args, **kwargs)
        self.encryption = encryption
        self.encryption_kwargs = encryption_kwargs

    def get_decrypter(self, zinfo, pwd):
        """Return a decrypter for an encrypted ``zinfo``, else None.

        Members with a WinZip AES extra field get ``AESZipDecrypter``; other
        encrypted members get ``CRCZipDecrypter``.
        """
        if not zinfo.is_encrypted:
            return None
        if zinfo.wz_aes_version is not None:
            return AESZipDecrypter(zinfo, pwd)
        return CRCZipDecrypter(zinfo, pwd)

    def get_encrypter(self, pwd=None):
        """Return an encrypter for a new member, or None if unsupported.

        ``pwd`` overrides the archive-level password set with
        ``setpassword`` when given.
        """
        if pwd is None:
            pwd = self.pwd
        if self.encryption != WZ_AES:
            return None
        if self.encryption_kwargs is None:
            encryption_kwargs = {}
        else:
            encryption_kwargs = self.encryption_kwargs
        return AESZipEncrypter(pwd=pwd, **encryption_kwargs)

    def get_zipwritefile(self, zinfo, zip64, pwd, **kwargs):
        """Create the writable stream for ``zinfo``, encrypting when requested.

        Encryption is used when ``pwd`` is given or an ``encryption`` method
        was configured.

        Raises:
            RuntimeError: encryption was requested but no supported
                encryption method is configured.
        """
        encrypter = None
        if pwd is not None or self.encryption is not None:
            encrypter = self.get_encrypter(pwd)
            if encrypter is None:
                # Previously this crashed with AttributeError on None.
                raise RuntimeError(
                    "Cannot encrypt %r: unsupported encryption method %r "
                    "(use encryption=%r)"
                    % (zinfo.filename, self.encryption, WZ_AES)
                )
            # Only flag the member as encrypted once an encrypter exists.
            zinfo.flag_bits |= _MASK_ENCRYPTED
            encrypter.update_zipinfo(zinfo)
        return self.zipwritefile_cls(self, zinfo, zip64, encrypter)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.