Created October 3, 2019 15:34
Save yyogo/a4609d41e8bfd3a82c65a9bd5d08c415 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" ASN1 DER/BER symmetric parser/encoder. """ | |
import enum | |
def decode_varint(bytearr):
    """Consume a base-128, big-endian integer from the front of *bytearr*.

    Every octet except the last has its high bit (0x80) set; the low 7 bits of
    each octet are concatenated most-significant first.

    Args:
        bytearr: a mutable sequence (e.g. ``bytearray``); the octets that make
            up the number are popped off its front.

    Returns:
        The decoded non-negative integer.

    Raises:
        ValueError: if *bytearr* is empty or ends before the final octet
            (one without the 0x80 continuation bit) is reached.
    """
    value = 0
    while True:
        if not bytearr:
            # Report truncation the same way the DER header parser does.
            raise ValueError("truncated varint")
        octet = bytearr.pop(0)
        value = (value << 7) | (octet & 0x7f)
        if not octet & 0x80:
            return value
def encode_varint(n):
    """Encode non-negative *n* as a base-128, big-endian integer.

    Every octet except the last has its high bit (0x80) set, which is the
    inverse of ``decode_varint``.

    Raises:
        ValueError: if *n* is negative (right-shifting a negative int never
            reaches zero, so the loop below would not terminate).
    """
    if n < 0:
        raise ValueError(f"cannot encode negative value {n} as base-128")
    groups = [n & 0x7f]  # last octet: no continuation bit
    n >>= 7
    while n:
        groups.append((n & 0x7f) | 0x80)
        n >>= 7
    return bytes(reversed(groups))
class BitField:
    """Mutable integer with bit- and slice-level access.

    Bit 0 is the least significant bit. ``bf[i]`` reads or writes a single bit;
    ``bf[a:b]`` reads or writes bits ``a``..``b-1`` as an integer whose LSB is
    bit ``a``.
    """

    def __init__(self, value=0):
        self.value = value

    def __int__(self):
        return self.value

    def __index__(self):
        # Lets a BitField be used directly where an int is required,
        # e.g. bytearray([bitfield]).
        return self.value

    def __getitem__(self, sl):
        if isinstance(sl, slice):
            bits = bin(self.value)[2:][::-1]  # LSB-first '0'/'1' string
            return int(bits[sl][::-1] or '0', 2)
        return (self.value >> sl.__index__()) & 1

    def __setitem__(self, sl, value):
        if isinstance(sl, slice):
            # Pad so the slice maps to real bit positions even above the
            # current top bit.
            size = max(sl.start or 0, sl.stop or 0, self.value.bit_length())
            bits = list(bin(self.value)[2:].zfill(size)[::-1])
            width = len(bits[sl])
            value = value.__index__()
            # Keep only the low `width` bits of `value`, so an oversized value
            # cannot grow the list and shift the bits above the slice.
            bits[sl] = ['1' if (value >> i) & 1 else '0' for i in range(width)]
            self.value = int(''.join(reversed(bits)), 2)
        else:
            mask = 1 << sl.__index__()
            # Set or clear the bit; a plain `|=` could never reset it to 0.
            if value:
                self.value |= mask
            else:
                self.value &= ~mask

    def __repr__(self):
        return f'<BitField({self.value:#x})>'
# Value of the low five identifier bits (all ones) meaning "the real tag number
# follows in base-128 high-tag-number form".
DER_TAG_EXTENDED = 0b11111
class ASN1Tag(enum.IntEnum):
    """Subset of the universal-class ASN.1 tag numbers (X.680)."""

    Boolean = 0x01
    Integer = 0x02
    BitString = 0x03
    OctetString = 0x04
    Null = 0x05
    OID = 0x06
    UTF8String = 0x0C
    Sequence = 0x10
    Set = 0x11
    PrintableString = 0x13
    T61String = 0x14
    IA5String = 0x16
    UTCTime = 0x17
    GeneralString = 0x1B
    UniversalString = 0x1C
class ASN1Class(enum.IntEnum):
    """ASN.1 tag class; encoded in bits 7-6 of the identifier octet."""

    Universal = 0b00
    Application = 0b01
    Context = 0b10
    Private = 0b11
class DERObject:
    """Base class for a single ASN.1 TLV (tag, length, value) object.

    Subclasses pin ``TAG`` (tag number), ``CLASS`` (an ``ASN1Class``) and
    ``CONS`` (constructed flag). ``None`` means "unconstrained", which lets
    ``find_class`` walk the subclass tree to the most specific match. An
    instance stores both the raw content octets (``data``) and the decoded
    ``value``; assigning either one updates the other.
    """
    CONS = CLASS = TAG = None

    def __init__(self, value=None, data=None):
        """Build from exactly one of a decoded *value* or raw content *data*.

        Raises:
            TypeError: if both or neither of *value*/*data* are given.
        """
        if not ((value is None) ^ (data is None)):
            raise TypeError(f"{self.__class__.__name__}: must supply exactly one of `data`, `value`")
        if value is None:
            value = self._decode_value(data)
        elif data is None:
            data = self._encode_value(value)
        self._data = data
        self._value = value

    def _encode_value(self, value):
        """Return the content octets for *value* (default: ``bytes(value)``)."""
        return bytes(value)

    def _decode_value(self, data):
        """Return the decoded value for content octets *data* (default: as-is)."""
        return data

    @property
    def value(self):
        """The decoded value; assigning it re-encodes ``data``."""
        return self._value

    @value.setter
    def value(self, value):
        # Encode first so a failing encode leaves the object unchanged.
        data = self._encode_value(value)
        self._data = data
        self._value = value

    # Backward-compatible alias: the setter used to be registered under this
    # name, which left `value` read-only.
    setvalue = value

    @property
    def data(self):
        """The raw content octets, without the tag/length header."""
        return self._data

    @data.setter
    def data(self, data):
        # Decode first so a malformed payload leaves the object unchanged,
        # then store both sides (previously `_data` was never updated).
        value = self._decode_value(data)
        self._data = data
        self._value = value

    def __bytes__(self):
        """Return the full encoding: header octets followed by ``data``."""
        # Objects parsed from BER indefinite-length input (see from_bytes) are
        # re-emitted with the indefinite-length marker so they round-trip.
        if getattr(self, '_indefinite', False):
            length = None
        else:
            length = len(self.data)
        return self._encode_header(length) + self.data

    @classmethod
    def _encode_header(cls, length):
        """Return the identifier and length octets for this class.

        Args:
            length: content length in octets, or ``None`` to emit the BER
                indefinite-length marker (0x80).
        """
        tag, asn1cls, cons = cls.TAG, cls.CLASS, cls.CONS
        if tag >= 0x1f:
            # High-tag-number form: the low 5 bits are all ones and the real
            # tag number follows in base-128.
            tag = DER_TAG_EXTENDED
        head = BitField()
        head[6:8] = asn1cls
        head[5] = cons
        head[:5] = tag
        data = bytearray([head])
        if tag == DER_TAG_EXTENDED:
            data += encode_varint(cls.TAG)
        if length is None:
            data.append(0x80)
        elif length < 0x80:
            # Short form: the length fits in a single octet.
            data.append(length)
        else:
            # Long form: 0x80 | octet-count, then the big-endian length.
            byte_count = (length.bit_length() + 7) // 8
            data.append(0x80 | byte_count)
            data += length.to_bytes(byte_count, 'big')
        return bytes(data)

    @classmethod
    def _decode_header(cls, bytearr):
        """Consume the identifier and length octets from the front of *bytearr*.

        Returns:
            ``(tag, cons, asn1_class, length)``; ``length`` is ``None`` for the
            BER indefinite-length form.

        Raises:
            ValueError: on truncated or inconsistent header octets.
        """
        if not bytearr:
            raise ValueError("truncated der header")
        head = BitField(bytearr.pop(0))
        asn_cls, cons, tag = head[6:8], head[5], head[:5]
        if tag == DER_TAG_EXTENDED:
            tag = decode_varint(bytearr)
        if len(bytearr) == 0:
            raise ValueError("truncated der header")
        length = bytearr.pop(0)
        if length > 0x80:
            # Long form: the low 7 bits give the number of length octets.
            byte_count = length & 0x7f
            if byte_count > len(bytearr):
                raise ValueError("bad der length")
            length = int.from_bytes(bytearr[:byte_count], 'big')
            bytearr[:byte_count] = b''
        elif length == 0x80:
            length = None
        return (tag, cons, asn_cls, length, )

    @classmethod
    def from_bytes(cls, data):
        """Parse one TLV object from the start of *data*.

        Octets after the object are ignored; the returned object's ``length``
        attribute is the number of octets consumed (header plus content) and
        ``header_length`` is the header part alone.

        Raises:
            ValueError: if the header or content is truncated, or if *cls*
                cannot represent the tag found in *data*.
        """
        arr = bytearr = bytearray(data)
        tag, cons, asn1cls, length = cls._decode_header(bytearr)
        asn1cls = ASN1Class(asn1cls)
        if length is not None:
            if len(arr) < length:
                raise ValueError("Truncated DER object")
            indefinite = False
        else:
            # BER indefinite length: everything after the header is taken as
            # content, including any end-of-contents octets and any siblings
            # that follow in the buffer.
            indefinite = True
            length = len(arr)
        target = cls.find_class(tag, cons, asn1cls)
        if target is None:
            # Calling `None(...)` used to raise an opaque TypeError here.
            raise ValueError(
                f"{cls.__name__} cannot represent tag {tag:#x} "
                f"(class {asn1cls.name}, {'constructed' if cons else 'primitive'})")
        obj = target(data=bytes(arr[:length]))
        obj.header_length = len(data) - len(arr)
        obj.length = obj.header_length + length
        # Remember the indefinite form so __bytes__ can reproduce it.
        obj._indefinite = indefinite
        return obj

    def __repr__(self):
        return f"{self.__class__.__name__}({self.value!r})"

    @classmethod
    def find_class(cls, tag, cons, asn1cls):
        """Return the most specific subclass of *cls* matching the header fields.

        ``None``-valued ``TAG``/``CONS``/``CLASS`` attributes act as wildcards.
        Subclasses are searched depth-first via ``__subclasses__()``. Returns
        ``None`` when *cls* itself conflicts with the fields; if *cls* matches
        but no subclass does, a new subclass (named e.g. ``ContextCons[0x0]``)
        is created with the exact fields.
        """
        if cls.TAG == tag and cls.CONS == cons and cls.CLASS == asn1cls:
            return cls
        if cls.TAG is not None and cls.TAG != tag:
            return None
        if cls.CONS is not None and cls.CONS != cons:
            return None
        if cls.CLASS is not None and cls.CLASS != asn1cls:
            return None
        for subclass in cls.__subclasses__():
            found = subclass.find_class(tag, cons, asn1cls)
            if found is not None:
                return found
        # No existing class matches: synthesize one for this tag.
        return type(f"{ASN1Class(asn1cls).name}{'Cons' if cons else 'Prim'}[{tag:#x}]",
                    (cls,), {'TAG': tag, 'CLASS': asn1cls, 'CONS': cons})
class Constructed(DERObject):
    """Base for constructed encodings; the value is a list of child objects.

    Supports ``len()``, iteration and indexing over the children.
    """

    CONS = True

    def _encode_value(self, value):
        # Content octets are the complete encodings of the children, in order.
        return b''.join(map(bytes, value))

    def _decode_value(self, data):
        # Parse TLVs back-to-back until the content octets are exhausted.
        children = []
        remaining = data
        while remaining:
            child = DERObject.from_bytes(remaining)
            children.append(child)
            remaining = remaining[child.length:]
        return children

    def __iter__(self):
        return iter(self.value)

    def __getitem__(self, index):
        return self.value[index]

    def __len__(self):
        return len(self.value)
class Universal(DERObject):
    """Base restricting ``find_class`` matches to the UNIVERSAL tag class."""

    CLASS = ASN1Class.Universal
class UniversalConstructed(Universal, Constructed):
    """Universal-class constructed types; concrete tags are set by subclasses."""
class Sequence(UniversalConstructed):
    """ASN.1 SEQUENCE / SEQUENCE OF: an ordered list of child objects."""

    TAG = ASN1Tag.Sequence
class Set(UniversalConstructed):
    """ASN.1 SET / SET OF (children kept in encoded order)."""

    # CLASS repeats the value inherited from Universal.
    CLASS = ASN1Class.Universal
    TAG = ASN1Tag.Set
class Primitive(DERObject):
    """Base for primitive encodings; the value is derived from the content octets."""

    CONS = False

    def _decode_value(self, data):
        return data

    def _encode_value(self, value):
        return bytes(value)

    @property
    def value(self):
        """Decoded value, recomputed from ``data`` on every access."""
        return self._decode_value(self.data)

    @value.setter
    def value(self, value):
        # Previously registered as `setvalue`, which made `value` read-only.
        # Re-encode, then update the stored octets and cached value together.
        data = self._encode_value(value)
        self._data = data
        self._value = self._decode_value(data)

    # Backward-compatible alias for the previously misnamed property.
    setvalue = value
class UniversalPrimitive(Universal, Primitive):
    """Universal-class primitive types; concrete tags are set by subclasses."""
class Null(UniversalPrimitive):
    """ASN.1 NULL: no content octets; the decoded value is ``None``."""

    TAG = ASN1Tag.Null

    def _decode_value(self, data):
        # Explicit check rather than `assert`, so malformed input is still
        # rejected under `python -O`; ValueError matches the header parser.
        if len(data) != 0:
            raise ValueError("Null tag with data")
        return None

    def _encode_value(self, value):
        if value is not None:
            raise ValueError("Null tag with non-None value")
        return b''
class OID(UniversalPrimitive):
    """ASN.1 OBJECT IDENTIFIER; the decoded value is a dotted string (e.g. ``'1.2.840'``)."""

    TAG = ASN1Tag.OID

    def _decode_value(self, data):
        """Decode content octets into a dotted-decimal string.

        Raises:
            ValueError: on empty or truncated content.
        """
        remaining = bytearray(data)
        if not remaining:
            raise ValueError("empty OID")
        # The first subidentifier packs the first two arcs as 40*X + Y. It is a
        # base-128 number like the others and can exceed one octet when X == 2,
        # so it must not be read as a single byte.
        first = decode_varint(remaining)
        if first < 80:
            arcs = list(divmod(first, 40))
        else:
            arcs = [2, first - 80]
        while remaining:
            arcs.append(decode_varint(remaining))
        return '.'.join(str(arc) for arc in arcs)

    def _encode_value(self, value):
        """Encode a dotted string or a sequence of ints as content octets.

        Raises:
            ValueError: if fewer than two arcs are given or the leading arcs
                are out of range (first arc must be 0, 1 or 2; second < 40
                unless the first is 2).
        """
        if isinstance(value, str):
            value = [int(x) for x in value.split('.')]
        if len(value) < 2:
            raise ValueError("OID needs at least two arcs")
        top, second = value[0], value[1]
        if top not in (0, 1, 2) or second < 0 or (top < 2 and second >= 40):
            raise ValueError(f"invalid leading OID arcs {top}.{second}")
        data = bytearray(encode_varint(top * 40 + second))
        for arc in value[2:]:
            data += encode_varint(arc)
        return bytes(data)

    def __str__(self):
        # `value` is already the dotted string (the old code called a
        # nonexistent `self.oid()`).
        return self.value
class Integer(UniversalPrimitive):
    """ASN.1 INTEGER: big-endian two's-complement content (X.690 8.3)."""

    TAG = ASN1Tag.Integer

    def _decode_value(self, data):
        # INTEGER content is signed; DER prefixes 0x00 to keep large positives
        # positive, so a set top bit really means a negative number.
        return int.from_bytes(data, 'big', signed=True)

    def _encode_value(self, value):
        # Minimal length that leaves room for the sign bit, and always at least
        # one octet: 0 -> b'\x00', 128 -> b'\x00\x80', -128 -> b'\x80'.
        magnitude = ~value if value < 0 else value
        byte_length = magnitude.bit_length() // 8 + 1
        return value.to_bytes(byte_length, 'big', signed=True)

    def __int__(self):
        return self.value
class Boolean(UniversalPrimitive):
    """ASN.1 BOOLEAN."""

    TAG = ASN1Tag.Boolean

    def _decode_value(self, data):
        # Lenient (BER-style): any non-zero content decodes as True.
        return bool(int.from_bytes(data, 'big'))

    def _encode_value(self, value):
        # DER (X.690 11.1) requires exactly one octet: 0xFF for TRUE and 0x00
        # for FALSE. The old code produced b'\x01' and an empty b'' instead.
        return b'\xff' if value else b'\x00'

    def __int__(self):
        # Integer view of the raw content octets (0xFF -> 255).
        return int.from_bytes(self.data, 'big')
class StringType:
    """Mixin for character-string types; subclasses define ``ENCODING``."""

    def _decode_value(self, data):
        codec = self.ENCODING
        return data.decode(codec)

    def _encode_value(self, value):
        codec = self.ENCODING
        return value.encode(codec)

    def __str__(self):
        text = self.value
        return str(text)
class UTF8String(StringType, UniversalPrimitive):
    """ASN.1 UTF8String."""

    ENCODING = 'utf-8'
    TAG = ASN1Tag.UTF8String
class GeneralString(StringType, UniversalPrimitive):
    """ASN.1 GeneralString, handled here with the UTF-8 codec."""

    # NOTE(review): GeneralString is nominally ISO 2022-based; UTF-8 is a
    # pragmatic choice here. Confirm it suits the data being parsed.
    ENCODING = 'utf-8'
    TAG = ASN1Tag.GeneralString
class UniversalString(StringType, UniversalPrimitive):
    """ASN.1 UniversalString: UCS-4, four big-endian octets per character."""

    TAG = ASN1Tag.UniversalString
    # UCS-4 is UTF-32 big-endian without a BOM; decoding it as UTF-8 turns
    # every character into three NULs plus the low byte.
    ENCODING = 'utf-32-be'
class IA5String(StringType, UniversalPrimitive):
    """ASN.1 IA5String (7-bit ASCII)."""

    ENCODING = 'ascii'
    TAG = ASN1Tag.IA5String
class PrintableString(StringType, UniversalPrimitive):
    """ASN.1 PrintableString, handled with the ASCII codec.

    The narrower PrintableString character set is not enforced.
    """

    ENCODING = 'ascii'
    TAG = ASN1Tag.PrintableString
class OctetString(UniversalPrimitive):
    """ASN.1 OCTET STRING; the value is the raw content bytes."""

    TAG = ASN1Tag.OctetString
class BitString(UniversalPrimitive):
    """ASN.1 BIT STRING; ``data`` holds the raw content octets."""

    TAG = ASN1Tag.BitString

    def __int__(self):
        # NOTE(review): per X.690 the first content octet of a BIT STRING is
        # the unused-bits count; this conversion keeps that octet as the most
        # significant byte. Confirm this integer view is what callers expect.
        return int.from_bytes(self.data, byteorder='big')
if __name__ == '__main__':
    # Self-test: parse a real DER certificate and check it re-encodes
    # byte-for-byte, then pretty-print the parsed tree.
    import base64
    # Google CA
    test = base64.b64decode(
"""MIIEKDCCAxCgAwIBAgIQAQAhJYiw+lmnd+8Fe2Yn3zANBgkqhkiG9w0BAQsFADBCMQswCQYDVQQGEwJV
UzEWMBQGA1UEChMNR2VvVHJ1c3QgSW5jLjEbMBkGA1UEAxMSR2VvVHJ1c3QgR2xvYmFsIENBMB4XDTE3
MDUyMjExMzIzN1oXDTE4MTIzMTIzNTk1OVowSTELMAkGA1UEBhMCVVMxEzARBgNVBAoTCkdvb2dsZSBJ
bmMxJTAjBgNVBAMTHEdvb2dsZSBJbnRlcm5ldCBBdXRob3JpdHkgRzIwggEiMA0GCSqGSIb3DQEBAQUA
A4IBDwAwggEKAoIBAQCcKgR3XNhQkToGo4Lg2FBIvIk/8RlwGohGfuCPxfGJziHuWv5hDbcyRImgdAtT
T1WkzoJile7rWV/G4QWAEsRelD+8W0g49FP3JOb7kekVxM/0Uw30SvyfVN59vqBrb4fA0FAfKDADQNoI
c1Fsf/86PKc3Bo69SxEE630k3ub5/DFx+5TVYPMuSq9C0svqxGoassxT3RVLix/IGWEfzZ2oPmMrhDVp
ZYTIGcVGIvhTlb7jgEoQxirsupcgEcc5mRAEoPBhepUljE5SdeK27QjKFPzOImqzTs9GA5eXA37Asd57
r0Uzz7o+cbfe9CUlwg01iZ2d+w4ReYkeN8WvjnJpAgMBAAGjggERMIIBDTAfBgNVHSMEGDAWgBTAepho
jYn7qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1dvWBtrtiGrpagS8wDgYDVR0PAQH/BAQD
AgEGMC4GCCsGAQUFBwEBBCIwIDAeBggrBgEFBQcwAYYSaHR0cDovL2cuc3ltY2QuY29tMBIGA1UdEwEB
/wQIMAYBAf8CAQAwNQYDVR0fBC4wLDAqoCigJoYkaHR0cDovL2cuc3ltY2IuY29tL2NybHMvZ3RnbG9i
YWwuY3JsMCEGA1UdIAQaMBgwDAYKKwYBBAHWeQIFATAIBgZngQwBAgIwHQYDVR0lBBYwFAYIKwYBBQUH
AwEGCCsGAQUFBwMCMA0GCSqGSIb3DQEBCwUAA4IBAQDKSeWs12Rkd1u+cfrP9B4jx5ppY1Rf60zWGSgj
ZGaOHMeHgGRfBIsmr5jfCnC8vBk97nszqX+99AXUcLsFJnnqmseYuQcZZTTMPOk/xQH6bwx+23pwXEz+
LQDwyr4tjrSogPsBE4jLnD/lu3fKOmc2887VJwJyQ6C9bgLxRwVxPgFZ6RGeGvOED4Cmong1L7bHon8X
fOGLVq7uZ4hRJzBgpWJSwzfVO+qFKgE4h6LPcK2kesnE58rF2rwjMvL+GMJ74N87L9TQEOaWTPtEtyFk
DbkAlDASJodYmDkFOA/MgkgMCkdm7r+0X8T/cKjhf4t5K7hlMqO5tzHpCvX2HzLc""")
    obj = DERObject.from_bytes(test)
    assert bytes(obj) == test
    try:
        # Optional third-party pretty printer; fall back to the stdlib one so
        # the self-test runs without it.
        from rupy import pp
    except ImportError:
        from pprint import pprint as pp
    pp(obj)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Tested only on Python 3.7.