Skip to content

Instantly share code, notes, and snippets.

@yyogo
Created October 3, 2019 15:34
Show Gist options
  • Save yyogo/a4609d41e8bfd3a82c65a9bd5d08c415 to your computer and use it in GitHub Desktop.
Save yyogo/a4609d41e8bfd3a82c65a9bd5d08c415 to your computer and use it in GitHub Desktop.
""" ASN1 DER/BER symmetric parser/encoder. """
import enum
def decode_varint(bytearr):
    """Pop one big-endian base-128 varint off the front of ``bytearr``.

    Each byte contributes its low 7 bits; a set high bit (0x80) means another
    byte follows. The consumed bytes are removed from ``bytearr`` in place.

    :param bytearr: mutable buffer (``bytearray``) positioned at the varint.
    :return: the decoded non-negative integer.
    :raises IndexError: if the buffer ends before the varint does.
    """
    result = 0
    while True:
        octet = bytearr.pop(0)
        result = (result << 7) | (octet & 0x7f)
        if not octet & 0x80:
            return result
def encode_varint(n):
    """Encode a non-negative integer as a big-endian base-128 varint.

    Every byte except the last has its high bit (0x80) set as a continuation
    flag. This module uses the form for OID sub-identifiers and for tag
    numbers >= 0x1f.

    :param n: non-negative integer to encode.
    :return: the encoded bytes (a single zero byte for ``n == 0``).
    :raises ValueError: if ``n`` is negative. Without this check the loop
        would never end, because ``-1 >> 7 == -1``.
    """
    if n < 0:
        raise ValueError(f"cannot varint-encode negative value {n}")
    # Collect 7-bit groups least-significant first, then reverse them into
    # big-endian order. Only the final (least significant) group lacks 0x80.
    groups = [n & 0x7f]
    n >>= 7
    while n:
        groups.append((n & 0x7f) | 0x80)
        n >>= 7
    return bytes(reversed(groups))
class BitField:
    """Mutable non-negative integer whose bits are addressable by index or slice.

    Bit 0 is the least-significant bit. Slices select bits in LSB-first order:
    ``bf[6:8]`` is the 2-bit field made of bits 6 and 7, read and assigned as
    an ordinary integer.
    """

    def __init__(self, value=0):
        self.value = value

    def __int__(self):
        return self.value

    def __index__(self):
        # Lets a BitField be used wherever an int is required,
        # e.g. ``bytearray([bf])``.
        return self.value

    def __getitem__(self, sl):
        """Return bit ``sl`` (0 or 1), or the integer formed by a slice of bits."""
        if isinstance(sl, slice):
            # Reverse the binary string so that string index == bit position.
            bits = bin(self.value)[2:][::-1]
            return int(bits[sl][::-1] or '0', 2)
        else:
            return (self.value >> sl.__index__()) & 1

    def __setitem__(self, sl, value):
        """Set bit ``sl`` to ``bool(value)``, or store integer ``value`` in a slice.

        :raises ValueError: if a slice is given a negative ``value``, or a
            ``value`` wider than a bounded slice. An open-ended slice such as
            ``bf[4:]`` grows to fit the value.
        """
        if isinstance(sl, slice):
            value = value.__index__()
            if value < 0:
                raise ValueError(f"cannot store negative value {value} in a BitField slice")
            size = max(sl.start or 0, sl.stop or 0, self.value.bit_length())
            # LSB-first list of '0'/'1', padded so the slice bounds exist.
            bits = list(bin(self.value)[2:].zfill(size)[::-1])
            width = len(bits[sl])
            if sl.stop is None:
                width = max(width, value.bit_length())
            elif value.bit_length() > width:
                # Writing more bits than the slice holds would push every
                # higher bit upward instead of replacing just this field.
                raise ValueError(f"value {value:#x} does not fit in a {width}-bit slice")
            bits[sl] = [str((value >> i) & 1) for i in range(width)]
            self.value = int(''.join(reversed(bits)), 2)
        else:
            mask = 1 << sl.__index__()
            # Clear as well as set, so ``bf[n] = 0`` actually turns a bit off.
            if value:
                self.value |= mask
            else:
                self.value &= ~mask

    def __repr__(self):
        return f'<BitField({self.value:#x})>'
# Value of the low five identifier bits that selects the high-tag-number form.
# When a tag number is >= 0x1f, the identifier octet carries this marker and the
# real tag number follows as a base-128 varint (see DERObject._encode_header
# and DERObject._decode_header).
DER_TAG_EXTENDED = 0x1f
class ASN1Tag(enum.IntEnum):
    """Universal-class tag numbers known to this module (written in hex)."""

    # Primitive scalar types
    Boolean = 0x01
    Integer = 0x02
    BitString = 0x03
    OctetString = 0x04
    Null = 0x05
    OID = 0x06
    # Remaining types, in tag-number order
    UTF8String = 0x0c
    Sequence = 0x10
    Set = 0x11
    PrintableString = 0x13
    T61String = 0x14
    IA5String = 0x16
    UTCTime = 0x17
    GeneralString = 0x1b
    UniversalString = 0x1c
class ASN1Class(enum.IntEnum):
    """Tag class; encoded in the top two bits (7-6) of the identifier octet."""

    Universal = 0b00
    Application = 0b01
    Context = 0b10
    Private = 0b11
class DERObject:
    """Base class for one ASN.1 BER/DER element (identifier, length, content).

    Subclasses set ``TAG`` (tag number), ``CLASS`` (an ``ASN1Class``) and
    ``CONS`` (constructed flag). ``None`` in any of them acts as a wildcard
    when ``find_class`` picks the class for a parsed element.

    An instance holds the raw content octets (``data``) together with their
    decoded form (``value``). Assigning either one updates the other.
    """
    CONS = CLASS = TAG = None

    def __init__(self, value=None, data=None):
        """Create an element from exactly one of ``value`` or ``data``.

        :param value: decoded value; converted with ``_encode_value``.
        :param data: raw content octets; converted with ``_decode_value``.
        :raises TypeError: if both or neither are supplied. ``None`` means
            "not supplied", so an element whose decoded value is ``None`` must
            be built from ``data``.
        """
        if not ((value is None) ^ (data is None)):
            raise TypeError(f"{self.__class__.__name__}: must supply exactly one of `data`, `value`")
        if value is None:
            value = self._decode_value(data)
        elif data is None:
            data = self._encode_value(value)
        self._data = data
        self._value = value

    def _encode_value(self, value):
        """Return the content octets for ``value`` (default: ``bytes(value)``)."""
        return bytes(value)

    def _decode_value(self, data):
        """Return the decoded form of content octets (default: unchanged)."""
        return data

    @property
    def value(self):
        """Decoded value; assigning it re-encodes ``data``."""
        return self._value

    # The setter must reuse the name ``value``. Under any other name, it
    # creates a separate property and leaves ``value`` read-only.
    @value.setter
    def value(self, value):
        self._value = value
        self._data = self._encode_value(value)

    # Backward-compatible alias for the previously mis-named property.
    setvalue = value

    @property
    def data(self):
        """Raw content octets (without the identifier/length header)."""
        return self._data

    @data.setter
    def data(self, data):
        # Decode first so a decoding error leaves the object unchanged, then
        # store the octets too so ``data`` and ``value`` stay in sync.
        self._value = self._decode_value(data)
        self._data = data

    def __bytes__(self):
        """Serialize the element: header followed by the content octets.

        Elements parsed from a BER indefinite-length encoding are re-emitted
        with the indefinite length octet (0x80), keeping round-trips exact.
        """
        # BER indefinite-length hack: ``_indefinite`` is only set by from_bytes.
        if getattr(self, '_indefinite', False):
            length = None
        else:
            length = len(self.data)
        return self._encode_header(length) + self.data

    @classmethod
    def _encode_header(cls, length):
        """Return the identifier and length octets for this class.

        :param length: content length in bytes, or ``None`` for the BER
            indefinite form (length octet 0x80).
        """
        tag, asn1cls, cons = cls.TAG, cls.CLASS, cls.CONS
        if tag >= 0x1f: #extended
            tag = DER_TAG_EXTENDED
        # Identifier octet: bits 7-6 class, bit 5 constructed, bits 4-0 tag.
        head = BitField()
        head[6:8] = asn1cls
        head[5] = cons
        head[:5] = tag
        data = bytearray([head])
        if tag == DER_TAG_EXTENDED:
            # High-tag-number form: the real tag follows as a base-128 varint.
            data += encode_varint(cls.TAG)
        if length is None:
            data.append(0x80)
        elif length < 0x80:
            # Short form: one length octet.
            data.append(length)
        else:
            # Long form: 0x80 | number of length bytes, then the big-endian length.
            byte_count = (length.bit_length() + 7) // 8
            data.append(0x80 | byte_count)
            data += length.to_bytes(byte_count, 'big')
        return bytes(data)

    @classmethod
    def _decode_header(cls, bytearr):
        """Consume the identifier and length octets from the front of ``bytearr``.

        ``bytearr`` is mutated in place; afterwards it starts at the content
        octets.

        :return: ``(tag, cons, asn1_class, length)``. ``length`` is ``None`` for
            the indefinite form (length octet 0x80).
        :raises ValueError: if the length octets are missing or truncated.
        """
        head = BitField(bytearr.pop(0))
        asn_cls, cons, tag = head[6:8], head[5], head[:5]
        if tag == DER_TAG_EXTENDED:
            tag = decode_varint(bytearr)
        if len(bytearr) == 0:
            raise ValueError("truncated der header")
        length = bytearr.pop(0)
        if length > 0x80:
            # Long form: the low 7 bits count the following length bytes.
            byte_count = length & 0x7f
            if byte_count > len(bytearr):
                raise ValueError("bad der length")
            length = int.from_bytes(bytearr[:byte_count], 'big')
            bytearr[:byte_count] = b''
        elif length == 0x80:
            length = None
        return (tag, cons, asn_cls, length, )

    @classmethod
    def from_bytes(cls, data):
        """Parse the element at the start of ``data``.

        The concrete class is chosen with ``cls.find_class``, and bytes after
        the element are ignored. The result also gets ``header_length`` (size of
        the identifier and length octets) and ``length`` (total encoded size).

        :raises ValueError: on a truncated header or truncated content.
        :raises TypeError: if ``cls`` cannot represent the parsed element
            (e.g. ``Integer.from_bytes`` applied to a SEQUENCE).
        """
        arr = bytearray(data)
        tag, cons, asn1cls, length = cls._decode_header(arr)
        asn1cls = ASN1Class(asn1cls)
        if length is not None:
            if len(arr) < length:
                raise ValueError("Truncated DER object")
            indefinite = False
        else:
            # BER indefinite length: the end-of-contents octets are not searched
            # for. The whole remaining buffer is taken as this element's content,
            # so re-encoding reproduces the input byte-for-byte.
            indefinite = True
            length = len(arr)
        target = cls.find_class(tag, cons, asn1cls)
        if target is None:
            kind = 'constructed' if cons else 'primitive'
            raise TypeError(
                f"{cls.__name__} cannot represent {asn1cls.name} {kind} tag {tag:#x}")
        obj = target(data=bytes(arr[:length]))
        obj.header_length = len(data) - len(arr)
        obj.length = obj.header_length + length
        # Remembered so __bytes__ can emit the indefinite length octet again.
        obj._indefinite = indefinite
        return obj

    def __repr__(self):
        return f"{self.__class__.__name__}({self.value!r})"

    @classmethod
    def find_class(cls, tag, cons, asn1cls):
        """Return the most specific class under ``cls`` matching the element.

        ``None`` class attributes are wildcards, and subclasses are searched
        depth-first. If ``cls`` matches but none of its subclasses does, a new
        subclass with the exact ``TAG``/``CLASS``/``CONS`` is created on the fly.
        While that subclass is alive, ``__subclasses__`` lets later lookups find
        it. Returns ``None`` when ``cls`` itself conflicts with the element.
        """
        if cls.TAG == tag and cls.CONS == cons and cls.CLASS == asn1cls:
            return cls
        if cls.TAG is not None and cls.TAG != tag:
            return None
        if cls.CONS is not None and cls.CONS != cons:
            return None
        if cls.CLASS is not None and cls.CLASS != asn1cls:
            return None
        for subclass in cls.__subclasses__():
            found = subclass.find_class(tag, cons, asn1cls)
            if found is not None:
                return found
        # class not found; create new
        return type(f"{ASN1Class(asn1cls).name}{'Cons' if cons else 'Prim'}[{tag:#x}]",
                    (cls,), {'TAG': tag, 'CLASS': asn1cls, 'CONS': cons})
class Constructed(DERObject):
    """Base for constructed encodings, whose content is a run of nested elements.

    ``value`` is the list of decoded child objects. The instance also supports
    iteration, indexing and ``len()`` over those children.
    """
    CONS = True

    def _encode_value(self, value):
        # The content octets are the children's encodings, back to back.
        encoded_children = [bytes(child) for child in value]
        return b''.join(encoded_children)

    def _decode_value(self, data):
        children = []
        rest = data
        while rest:
            child = DERObject.from_bytes(rest)
            children.append(child)
            # ``length`` (set by from_bytes) spans the child's header and content.
            rest = rest[child.length:]
        return children

    def __iter__(self):
        return iter(self.value)

    def __getitem__(self, index):
        children = self.value
        return children[index]

    def __len__(self):
        children = self.value
        return len(children)
class Universal(DERObject):
    """Common base for every element in the UNIVERSAL tag class."""
    CLASS = ASN1Class.Universal
class UniversalConstructed(Universal, Constructed):
    """Constructed elements of the UNIVERSAL class (e.g. SEQUENCE, SET)."""
class Sequence(UniversalConstructed):
    """ASN.1 SEQUENCE / SEQUENCE OF (universal tag 16); children form ``value``."""
    TAG = ASN1Tag.Sequence
class Set(UniversalConstructed):
    """ASN.1 SET / SET OF (universal tag 17); children form ``value``."""
    CLASS = ASN1Class.Universal
    TAG = ASN1Tag.Set
class Primitive(DERObject):
    """Base for primitive encodings, whose content octets hold a single value.

    ``value`` is recomputed from ``data`` on every access, so the raw octets
    are the single source of truth.
    """
    CONS = False

    def _decode_value(self, data):
        return data

    def _encode_value(self, value):
        return bytes(value)

    @property
    def value(self):
        return self._decode_value(self.data)

    # The setter must reuse the name ``value``. Under any other name, it
    # creates a separate property and leaves ``value`` read-only.
    @value.setter
    def value(self, value):
        data = self._encode_value(value)
        # Write ``_data`` directly: the getter decodes from it, so it must hold
        # the new octets whatever the base class's ``data`` setter does.
        self._value = self._decode_value(data)
        self._data = data

    # Backward-compatible alias for the previously mis-named property.
    setvalue = value
class UniversalPrimitive(Universal, Primitive):
    """Primitive elements of the UNIVERSAL class (INTEGER, OID, strings, ...)."""
class Null(UniversalPrimitive):
    """ASN.1 NULL: empty content, decoded value ``None``.

    Build it with ``Null(data=b'')``. ``DERObject.__init__`` treats
    ``value=None`` as "not supplied".
    """
    TAG = ASN1Tag.Null

    def _decode_value(self, data):
        # Raise explicitly rather than ``assert`` so the check survives ``python -O``.
        if len(data) != 0:
            raise ValueError("Null tag with data")
        return None

    def _encode_value(self, value):
        if value is not None:
            raise ValueError("Null tag with non-None value")
        return b''
class OID(UniversalPrimitive):
    """ASN.1 OBJECT IDENTIFIER; ``value`` is the dotted-decimal string."""
    TAG = ASN1Tag.OID

    def _decode_value(self, data):
        """Decode content octets into a dotted string such as ``'1.2.840.113549'``.

        :raises IndexError: if ``data`` is empty or ends mid-varint.
        """
        d = bytearray(data)
        # The first sub-identifier packs the first two arcs as X*40 + Y. It is a
        # varint itself, which is multi-byte for 2.48 and above. X can only be
        # 0, 1 or 2, so every value >= 80 belongs to arc 2.
        first = decode_varint(d)
        top = min(first // 40, 2)
        oid = [top, first - 40 * top]
        while d:
            oid.append(decode_varint(d))
        return '.'.join(str(x) for x in oid)

    def _encode_value(self, value):
        """Encode a dotted string or a sequence of ints into content octets."""
        if isinstance(value, str):
            value = [int(x) for x in value.split('.')]
        # The combined first sub-identifier may exceed one byte (arc 2.48+),
        # so it is varint-encoded like every other arc.
        data = bytearray(encode_varint(value[0] * 40 + value[1]))
        for n in value[2:]:
            data += encode_varint(n)
        return bytes(data)

    def __str__(self):
        # ``value`` is already the dotted-decimal string.
        return str(self.value)
class Integer(UniversalPrimitive):
    """ASN.1 INTEGER: minimal-length, big-endian two's complement content."""
    TAG = ASN1Tag.Integer

    def _decode_value(self, data):
        # INTEGER content is signed. A DER positive value whose top bit would be
        # set carries a leading 0x00 octet, so it still decodes as positive.
        return int.from_bytes(data, 'big', signed=True)

    def _encode_value(self, value):
        # The smallest byte count that leaves room for the sign bit. This also
        # gives 0 a single 0x00 octet instead of empty content, and encodes 128
        # as 00 80 rather than 80 (which would mean -128).
        magnitude = value if value >= 0 else ~value
        byte_length = magnitude.bit_length() // 8 + 1
        return value.to_bytes(byte_length, 'big', signed=True)

    def __int__(self):
        return self.value
class Boolean(UniversalPrimitive):
    """ASN.1 BOOLEAN: one content octet, 0x00 for FALSE and non-zero for TRUE."""
    TAG = ASN1Tag.Boolean

    def _decode_value(self, data):
        # Any non-zero content decodes as TRUE.
        return bool(int.from_bytes(data, 'big'))

    def _encode_value(self, value):
        # Always exactly one octet; DER requires 0xFF for TRUE. Previously FALSE
        # produced empty content and TRUE produced 0x01.
        return b'\xff' if value else b'\x00'

    def __int__(self):
        """Return the raw content octets as an unsigned int (e.g. 255 for 0xFF)."""
        return int.from_bytes(self.data, 'big')
class StringType:
    """Mixin for character-string element types.

    Subclasses define ``ENCODING``, the codec used to turn content octets into
    ``str`` and back.
    """

    def _decode_value(self, data):
        codec = self.ENCODING
        return data.decode(codec)

    def _encode_value(self, value):
        codec = self.ENCODING
        return value.encode(codec)

    def __str__(self):
        text = self.value
        return str(text)
class UTF8String(StringType, UniversalPrimitive):
    """ASN.1 UTF8String (universal tag 12), decoded with UTF-8."""
    ENCODING = 'utf-8'
    TAG = ASN1Tag.UTF8String
class GeneralString(StringType, UniversalPrimitive):
    """ASN.1 GeneralString (universal tag 27), decoded here as UTF-8.

    NOTE(review): UTF-8 is this module's chosen codec for this type; confirm it
    matches the data being parsed.
    """
    ENCODING = 'utf-8'
    TAG = ASN1Tag.GeneralString
class UniversalString(StringType, UniversalPrimitive):
    """ASN.1 UniversalString (universal tag 28).

    Characters are stored as four-octet ISO 10646 code points, i.e. big-endian
    UTF-32. Decoding them as UTF-8 turns each ASCII character into three NULs
    plus the letter and fails on most non-ASCII text.
    """
    TAG = ASN1Tag.UniversalString
    ENCODING = 'utf-32-be'
class IA5String(StringType, UniversalPrimitive):
    """ASN.1 IA5String (universal tag 22), decoded as ASCII."""
    ENCODING = 'ascii'
    TAG = ASN1Tag.IA5String
class PrintableString(StringType, UniversalPrimitive):
    """ASN.1 PrintableString (universal tag 19), decoded as ASCII."""
    ENCODING = 'ascii'
    TAG = ASN1Tag.PrintableString
class OctetString(UniversalPrimitive):
    """ASN.1 OCTET STRING (universal tag 4); ``value`` is the raw content bytes."""
    TAG = ASN1Tag.OctetString
class BitString(UniversalPrimitive):
    """ASN.1 BIT STRING (universal tag 3); ``data`` keeps the raw content octets."""
    TAG = ASN1Tag.BitString

    def __int__(self):
        """Return the bit string as an unsigned integer, without padding bits.

        The first content octet counts the unused (padding) bits at the end of
        the last octet, and is not part of the bits themselves. Empty content
        yields 0.
        """
        data = self.data
        if not data:
            return 0
        unused = data[0]
        return int.from_bytes(data[1:], 'big') >> unused
if __name__ == '__main__':
    import base64
    try:
        # Optional third-party pretty-printer.
        from rupy import pp
    except ImportError:
        # Fall back to the standard library when rupy is not installed.
        from pprint import pprint as pp
    # Self-test: a real X.509 certificate (Google CA) must survive a
    # parse/serialize round trip byte-for-byte.
    test = base64.b64decode(
        """MIIEKDCCAxCgAwIBAgIQAQAhJYiw+lmnd+8Fe2Yn3zANBgkqhkiG9w0BAQsFADBCMQswCQYDVQQGEwJV
UzEWMBQGA1UEChMNR2VvVHJ1c3QgSW5jLjEbMBkGA1UEAxMSR2VvVHJ1c3QgR2xvYmFsIENBMB4XDTE3
MDUyMjExMzIzN1oXDTE4MTIzMTIzNTk1OVowSTELMAkGA1UEBhMCVVMxEzARBgNVBAoTCkdvb2dsZSBJ
bmMxJTAjBgNVBAMTHEdvb2dsZSBJbnRlcm5ldCBBdXRob3JpdHkgRzIwggEiMA0GCSqGSIb3DQEBAQUA
A4IBDwAwggEKAoIBAQCcKgR3XNhQkToGo4Lg2FBIvIk/8RlwGohGfuCPxfGJziHuWv5hDbcyRImgdAtT
T1WkzoJile7rWV/G4QWAEsRelD+8W0g49FP3JOb7kekVxM/0Uw30SvyfVN59vqBrb4fA0FAfKDADQNoI
c1Fsf/86PKc3Bo69SxEE630k3ub5/DFx+5TVYPMuSq9C0svqxGoassxT3RVLix/IGWEfzZ2oPmMrhDVp
ZYTIGcVGIvhTlb7jgEoQxirsupcgEcc5mRAEoPBhepUljE5SdeK27QjKFPzOImqzTs9GA5eXA37Asd57
r0Uzz7o+cbfe9CUlwg01iZ2d+w4ReYkeN8WvjnJpAgMBAAGjggERMIIBDTAfBgNVHSMEGDAWgBTAepho
jYn7qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1dvWBtrtiGrpagS8wDgYDVR0PAQH/BAQD
AgEGMC4GCCsGAQUFBwEBBCIwIDAeBggrBgEFBQcwAYYSaHR0cDovL2cuc3ltY2QuY29tMBIGA1UdEwEB
/wQIMAYBAf8CAQAwNQYDVR0fBC4wLDAqoCigJoYkaHR0cDovL2cuc3ltY2IuY29tL2NybHMvZ3RnbG9i
YWwuY3JsMCEGA1UdIAQaMBgwDAYKKwYBBAHWeQIFATAIBgZngQwBAgIwHQYDVR0lBBYwFAYIKwYBBQUH
AwEGCCsGAQUFBwMCMA0GCSqGSIb3DQEBCwUAA4IBAQDKSeWs12Rkd1u+cfrP9B4jx5ppY1Rf60zWGSgj
ZGaOHMeHgGRfBIsmr5jfCnC8vBk97nszqX+99AXUcLsFJnnqmseYuQcZZTTMPOk/xQH6bwx+23pwXEz+
LQDwyr4tjrSogPsBE4jLnD/lu3fKOmc2887VJwJyQ6C9bgLxRwVxPgFZ6RGeGvOED4Cmong1L7bHon8X
fOGLVq7uZ4hRJzBgpWJSwzfVO+qFKgE4h6LPcK2kesnE58rF2rwjMvL+GMJ74N87L9TQEOaWTPtEtyFk
DbkAlDASJodYmDkFOA/MgkgMCkdm7r+0X8T/cKjhf4t5K7hlMqO5tzHpCvX2HzLc""")
    obj = DERObject.from_bytes(test)
    assert bytes(obj) == test
    pp(obj)
@yyogo
Copy link
Author

yyogo commented Oct 3, 2019

Tested only in Python 3.7

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment