Skip to content

Instantly share code, notes, and snippets.

@wgaylord
Last active July 14, 2019 20:37
Show Gist options
  • Save wgaylord/cd1715cb8edb070540a3d4193d41367a to your computer and use it in GitHub Desktop.
Save wgaylord/cd1715cb8edb070540a3d4193d41367a to your computer and use it in GitHub Desktop.
Files to run.
"""
`pure-protobuf` contributors © 2011-2019
"""
from enum import IntEnum
class WireType(IntEnum):
    """
    Protobuf wire types: the low three bits of every encoded field key.

    See https://developers.google.com/protocol-buffers/docs/encoding#structure
    """

    VARINT = 0  # base-128 varint
    LONG_LONG = 1  # fixed 64-bit value
    BYTES = 2  # length-delimited payload
    LONG = 5  # fixed 32-bit value
from legacy import MessageType, Unicode, UVarint, Float32
from legacy import Bool, EmbeddedMessage, Bytes, Flags
# Version (message type 0): version/platform information.
Version = (
    MessageType()
    .add_field(1, "version", UVarint)
    .add_field(2, "release", Unicode)
    .add_field(3, "os", Unicode)
    .add_field(4, "os_version", Unicode)
)
# Authenticate (message type 2): login request sent by the client.
# `tokens` is `repeated string` in Mumble.proto, so it is declared REPEATED:
# its value is a list of strings (each one becomes its own field on the wire).
Authenticate = MessageType()
Authenticate.add_field(1, "username", Unicode)
Authenticate.add_field(2, "password", Unicode)
Authenticate.add_field(3, "tokens", Unicode, flags=Flags.REPEATED)
# Ping (message type 3): timestamp plus packet/latency statistics.
Ping = (
    MessageType()
    .add_field(1, "timestamp", UVarint)
    .add_field(2, "good", UVarint)
    .add_field(3, "late", UVarint)
    .add_field(4, "lost", UVarint)
    .add_field(5, "resync", UVarint)
    .add_field(6, "udp_packets", UVarint)
    .add_field(7, "tcp_packets", UVarint)
    .add_field(8, "udp_ping_avg", Float32)
    .add_field(9, "udp_ping_var", Float32)
    .add_field(10, "tcp_ping_avg", Float32)
    .add_field(11, "tcp_ping_var", Float32)
)
# Reject (message type 4).
Reject = (
    MessageType()
    .add_field(1, "type", UVarint)
    .add_field(2, "reason", Unicode)
)

# ServerSync (message type 5).
ServerSync = (
    MessageType()
    .add_field(1, "session", UVarint)
    .add_field(2, "max_bandwidth", UVarint)
    .add_field(3, "welcome_text", Unicode)
    .add_field(4, "permissions", UVarint)
)

# ChannelRemove (message type 6).
ChannelRemove = MessageType().add_field(1, "channel_id", UVarint)

# ChannelState (message type 7).
# NOTE(review): Mumble.proto declares `position` as int32; UVarint cannot
# round-trip negative positions — confirm whether that matters here.
ChannelState = (
    MessageType()
    .add_field(1, "channel_id", UVarint)
    .add_field(2, "parent", UVarint)
    .add_field(3, "name", Unicode)
    .add_field(4, "links", UVarint, flags=Flags.REPEATED)
    .add_field(5, "description", Unicode)
    .add_field(6, "links_add", UVarint, flags=Flags.REPEATED)
    .add_field(7, "links_remove", UVarint, flags=Flags.REPEATED)
    .add_field(8, "temporary", Bool)
    .add_field(9, "position", UVarint)
    .add_field(10, "description_hash", Bytes)
    .add_field(11, "max_users", UVarint)
)

# UserRemove (message type 8).
UserRemove = (
    MessageType()
    .add_field(1, "session", UVarint)
    .add_field(2, "actor", UVarint)
    .add_field(3, "reason", Unicode)
    .add_field(4, "ban", Bool)
)
# UserState (message type 9): state of one connected user.
# comment_hash (16) is `bytes` in Mumble.proto (a raw hash digest), so it is
# declared Bytes; decoding arbitrary hash bytes as UTF-8 raises
# UnicodeDecodeError. Field 18 uses the upstream name `priority_speaker`.
UserState = MessageType()
UserState.add_field(1, "session", UVarint)
UserState.add_field(2, "actor", UVarint)
UserState.add_field(3, "name", Unicode)
UserState.add_field(4, "user_id", UVarint)
UserState.add_field(5, "channel_id", UVarint)
UserState.add_field(6, "mute", Bool)
UserState.add_field(7, "deaf", Bool)
UserState.add_field(8, "suppress", Bool)
UserState.add_field(9, "self_mute", Bool)
UserState.add_field(10, "self_deaf", Bool)
UserState.add_field(11, "texture", Bytes)
UserState.add_field(12, "plugin_context", Bytes)
UserState.add_field(13, "plugin_identity", Unicode)
UserState.add_field(14, "comment", Unicode)
UserState.add_field(15, "hash", Unicode)
UserState.add_field(16, "comment_hash", Bytes)
UserState.add_field(17, "texture_hash", Bytes)
UserState.add_field(18, "priority_speaker", Bool)
UserState.add_field(19, "recording", Bool)
# One entry of a BanList; only used embedded in BanList.bans.
_BanEntry = (
    MessageType()
    .add_field(1, "address", Bytes)
    .add_field(2, "mask", UVarint)
    .add_field(3, "name", Unicode)
    .add_field(4, "hash", Unicode)
    .add_field(5, "reason", Unicode)
    .add_field(6, "start", Unicode)
    .add_field(7, "duration", UVarint)
)

# BanList (message type 10).
BanList = (
    MessageType()
    .add_field(1, "bans", EmbeddedMessage(_BanEntry), flags=Flags.REPEATED)
    .add_field(2, "query", Bool)
)
# TextMessage (message type 11): a chat message.
# In Mumble.proto the recipient fields session, channel_id and tree_id are
# `repeated uint32`; they are declared REPEATED so that every recipient is
# kept (their values are lists) instead of only the last one on the wire.
TextMessage = MessageType()
TextMessage.add_field(1, "actor", UVarint)
TextMessage.add_field(2, "session", UVarint, flags=Flags.REPEATED)
TextMessage.add_field(3, "channel_id", UVarint, flags=Flags.REPEATED)
TextMessage.add_field(4, "tree_id", UVarint, flags=Flags.REPEATED)
TextMessage.add_field(5, "message", Unicode)
# PermissionDenied (message type 12).
PermissionDenied = (
    MessageType()
    .add_field(1, "permission", UVarint)
    .add_field(2, "channel_id", UVarint)
    .add_field(3, "session", UVarint)
    .add_field(4, "reason", Unicode)
    .add_field(5, "type", UVarint)
    .add_field(6, "name", Unicode)
)
# A channel group entry, embedded in ACL.groups.
# add, remove and inherited_members are `repeated uint32` in Mumble.proto
# (lists of user ids), so they are declared REPEATED.
ChanGroup = MessageType()
ChanGroup.add_field(1, "name", Unicode)
ChanGroup.add_field(2, "inherited", Bool)
ChanGroup.add_field(3, "inherit", Bool)
ChanGroup.add_field(4, "inheritable", Bool)
ChanGroup.add_field(5, "add", UVarint, flags=Flags.REPEATED)
ChanGroup.add_field(6, "remove", UVarint, flags=Flags.REPEATED)
ChanGroup.add_field(7, "inherited_members", UVarint, flags=Flags.REPEATED)
# A single ACL rule, embedded in ACL.acls.
ChanACL = (
    MessageType()
    .add_field(1, "apply_here", Bool)
    .add_field(2, "apply_subs", Bool)
    .add_field(3, "inherited", Bool)
    .add_field(4, "user_id", UVarint)
    .add_field(5, "group", Unicode)
    .add_field(6, "grant", UVarint)
    .add_field(7, "deny", UVarint)
)

# ACL (message type 13): groups and rules of one channel.
ACL = (
    MessageType()
    .add_field(1, "channel_id", UVarint)
    .add_field(2, "inherit_acls", Bool)
    .add_field(3, "groups", EmbeddedMessage(ChanGroup), flags=Flags.REPEATED)
    .add_field(4, "acls", EmbeddedMessage(ChanACL), flags=Flags.REPEATED)
    .add_field(5, "query", Bool)
)
# QueryUsers (message type 14).
QueryUsers = (
    MessageType()
    .add_field(1, "ids", UVarint, flags=Flags.REPEATED)
    .add_field(2, "names", Unicode, flags=Flags.REPEATED)
)

# CryptoSetup (message type 15).
CryptoSetup = (
    MessageType()
    .add_field(1, "key", Bytes)
    .add_field(2, "client_nonce", Bytes)
    .add_field(3, "server_nonce", Bytes)
)

# ContextActionModify (message type 16).
ContextActionModify = (
    MessageType()
    .add_field(1, "action", Unicode)
    .add_field(2, "text", Unicode)
    .add_field(3, "context", UVarint)
    .add_field(4, "operation", UVarint)
)

# ContextAction (message type 17).
ContextAction = (
    MessageType()
    .add_field(1, "session", UVarint)
    .add_field(2, "channel_id", UVarint)
    .add_field(3, "action", Unicode)
)

# One registered user, embedded in UserList.users.
User = (
    MessageType()
    .add_field(1, "user_id", UVarint)
    .add_field(2, "name", Unicode)
    .add_field(3, "last_seen", Unicode)
    .add_field(4, "last_channel", UVarint)
)

# UserList (message type 18).
UserList = MessageType().add_field(
    1, "users", EmbeddedMessage(User), flags=Flags.REPEATED
)

# PermissionQuery (message type 20).
PermissionQuery = (
    MessageType()
    .add_field(1, "channel_id", UVarint)
    .add_field(2, "permissions", UVarint)
    .add_field(3, "flush", Bool)
)
# CodecVersion (message type 21).
# Field 3 uses the upstream Mumble.proto name `prefer_alpha`.
# NOTE(review): alpha/beta are int32 in Mumble.proto; a negative value is sent
# as a 10-byte two's-complement varint and therefore decodes here as a large
# unsigned number — confirm whether callers need it converted back.
CodecVersion = MessageType()
CodecVersion.add_field(1, "alpha", UVarint)
CodecVersion.add_field(2, "beta", UVarint)
CodecVersion.add_field(3, "prefer_alpha", Bool)
CodecVersion.add_field(4, "opus", Bool)
# ServerConfig (message type 24).
# Field 6 uses the upstream Mumble.proto name `max_users`.
ServerConfig = MessageType()
ServerConfig.add_field(1, "max_bandwidth", UVarint)
ServerConfig.add_field(2, "welcome_text", Unicode)
ServerConfig.add_field(3, "allow_html", Bool)
ServerConfig.add_field(4, "message_length", UVarint)
ServerConfig.add_field(5, "image_message_length", UVarint)
ServerConfig.add_field(6, "max_users", UVarint)
# Message-type number -> MessageType used to parse that payload.
# Numbers without an entry (1, 19, 22, 23) are not modelled in this module.
Calls = {
    0: Version,
    2: Authenticate,
    3: Ping,
    4: Reject,
    5: ServerSync,
    6: ChannelRemove,
    7: ChannelState,
    8: UserRemove,
    9: UserState,
    10: BanList,
    11: TextMessage,
    12: PermissionDenied,
    13: ACL,
    14: QueryUsers,
    15: CryptoSetup,
    16: ContextActionModify,
    17: ContextAction,
    18: UserList,
    20: PermissionQuery,
    21: CodecVersion,
    24: ServerConfig,
}
# coding: utf-8
"""
Legacy interface.
`pure-protobuf` contributors © 2011-2019
"""
import struct
from io import BytesIO
def b(string):
    """
    Converts a text string whose characters are all in U+0000..U+00FF to
    bytes, mapping each character to the byte with the same value.

    Uses the canonical codec name 'latin-1' (the previous 'latin=1' only
    worked because CPython normalises codec names).
    """
    return bytes(string, encoding='latin-1')
class Type(object):
    """
    Abstract base of every field type.

    Subclasses implement dump()/load() on file-like objects; dumps() and
    loads() adapt those to in-memory bytes.
    """

    def dump(self, fp, value):
        """
        Writes ``value`` to the writable file-like ``fp``. Abstract.
        """
        raise TypeError('Don\'t call this directly.')

    def load(self, fp):
        """
        Reads one value from the readable file-like ``fp`` and returns it.
        Abstract.
        """
        raise TypeError('Don\'t call this directly.')

    def dumps(self, value):
        """
        Serializes ``value`` and returns the resulting bytes.
        """
        sink = BytesIO()
        self.dump(sink, value)
        return sink.getvalue()

    def loads(self, s):
        """
        Deserializes one value from the bytes ``s``.
        """
        source = BytesIO(s)
        return self.load(source)

    def __hash__(self):
        """
        Hashes by class name, so all instances of one field class hash alike.
        """
        return hash(type(self).__name__)
class UVarintType(Type):
    """
    Represents an unsigned Varint type (base-128 varint, wire type 0).
    """
    WIRE_TYPE = 0

    def dump(self, fp, value):
        """
        Writes ``value`` 7 bits per byte, least significant group first;
        every byte except the last has its high bit (0x80) set.

        Raises ValueError for a negative ``value``: it has no unsigned varint
        form, and right-shifting a negative int never reaches 0, so the loop
        below would never terminate.
        """
        if value < 0:
            raise ValueError(
                'Cannot encode negative value %r as UVarint.' % (value,)
            )
        shifted_value = True
        while shifted_value:
            shifted_value = value >> 7
            fp.write(bytearray(((value & 0x7F) | (0x80 if shifted_value != 0 else 0x00),)))
            value = shifted_value

    def load(self, fp):
        """
        Reads one varint from ``fp`` and returns it as a non-negative int.

        End-of-stream behaviour depends on ``fp``: EofWrapper raises
        EOFError, while a plain BytesIO makes ord() fail on b''.
        """
        value, shift, quantum = 0, 0, 0x80
        while (quantum & 0x80) == 0x80:
            quantum = ord(fp.read(1))
            value, shift = value + ((quantum & 0x7F) << shift), shift + 7
        return value
class VarintType(UVarintType):
    """
    Represents a signed Varint type using ZigZag encoding
    (0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...) on top of UVarintType.
    """

    def dump(self, fp, value):
        zigzag = (abs(value) << 1) - (1 if value < 0 else 0)
        UVarintType.dump(self, fp, zigzag)

    def load(self, fp):
        raw = UVarintType.load(self, fp)
        magnitude = (raw + 1) >> 1
        # Odd raw values encode negatives, even ones non-negatives.
        return -magnitude if raw & 1 else magnitude
class BoolType(UVarintType):
    """
    Represents a boolean type. True is written as the single byte 0x01 and
    False as 0x00 (both one-byte UVarints); any non-zero UVarint loads as
    True.
    """

    def dump(self, fp, value):
        fp.write(b'\x01' if value else b'\x00')

    def load(self, fp):
        return bool(UVarintType.load(self, fp))
class BytesType(Type):
    """
    Represents a raw bytes type: a UVarint length prefix followed by the
    payload (wire type 2).
    """
    WIRE_TYPE = 2

    def dump(self, fp, value):
        size = len(value)
        UVarint.dump(fp, size)
        fp.write(value)

    def load(self, fp):
        size = UVarint.load(fp)
        return fp.read(size)
class UnicodeType(BytesType):
    """
    Represents a text type: the UTF-8 encoding of the string, stored like
    BytesType (UVarint length prefix + payload).
    """

    def dump(self, fp, value):
        encoded = value.encode('utf-8')
        return BytesType.dump(self, fp, encoded)

    def load(self, fp):
        raw = BytesType.load(self, fp)
        return raw.decode('utf-8')
class FixedLengthType(Type):
    """
    Base class for fixed-width values; not meant to be used directly.

    Subclasses provide length(). dump() writes already-packed bytes
    verbatim and load() requests length() bytes from the stream.
    """

    def length(self):
        raise NotImplementedError()

    def dump(self, fp, value):
        fp.write(value)

    def load(self, fp):
        size = self.length()
        return fp.read(size)
class Fixed64Type(FixedLengthType):
    """
    Represents a raw 8-byte value (wire type 1).
    """
    WIRE_TYPE = 1

    def length(self):
        return 64 // 8


class Fixed32Type(FixedLengthType):
    """
    Represents a raw 4-byte value (wire type 5).
    """
    WIRE_TYPE = 5

    def length(self):
        return 32 // 8
class Fixed64SubType(Fixed64Type):
    """
    Base for 64-bit values converted with ``struct``; subclasses set the
    class attribute ``format`` to a struct format string.
    """

    def dump(self, fp, value):
        packed = struct.pack(self.format, value)
        return Fixed64Type.dump(self, fp, packed)

    def load(self, fp):
        (value,) = struct.unpack(self.format, Fixed64Type.load(self, fp))
        return value


class UInt64Type(Fixed64SubType):
    """
    Represents an unsigned int64 type (little-endian).
    """
    format = '<Q'


class Int64Type(Fixed64SubType):
    """
    Represents a signed int64 type (little-endian).
    """
    format = '<q'


class Float64Type(Fixed64SubType):
    """
    Represents a double precision floating point type (little-endian).
    """
    format = '<d'
class Fixed32SubType(Fixed32Type):
    """
    Base for 32-bit values converted with ``struct``; subclasses set the
    class attribute ``format`` to a struct format string.
    """

    def dump(self, fp, value):
        packed = struct.pack(self.format, value)
        return Fixed32Type.dump(self, fp, packed)

    def load(self, fp):
        (value,) = struct.unpack(self.format, Fixed32Type.load(self, fp))
        return value


class UInt32Type(Fixed32SubType):
    """
    Represents an unsigned int32 type (little-endian).
    """
    format = '<I'


class Int32Type(Fixed32SubType):
    """
    Represents a signed int32 type (little-endian).
    """
    format = '<i'


class Float32Type(Fixed32SubType):
    """
    Represents a single precision floating point type (little-endian).
    """
    format = '<f'
# Type instances. --------------------------------------------------------------
# Shared field-type objects to pass to MessageType.add_field when defining a
# message type.

# Varint wire type (0).
UVarint, Varint, Bool = UVarintType(), VarintType(), BoolType()

# 64-bit wire type (1).
Fixed64, UInt64, Int64, Float64 = (
    Fixed64Type(), UInt64Type(), Int64Type(), Float64Type(),
)

# 32-bit wire type (5).
Fixed32, UInt32, Int32, Float32 = (
    Fixed32Type(), UInt32Type(), Int32Type(), Float32Type(),
)

# Length-delimited wire type (2).
Bytes, Unicode = BytesType(), UnicodeType()
# Messages. -------------------------------------------------------------------
class Flags(object):
    """
    Bit flags describing a field. Combine them with ``|``; test a flag by
    comparing ``flags & <MASK>`` with the flag value.
    """
    # Plain single-valued field.
    SIMPLE = 0
    # Bit 0: the field must be present.
    REQUIRED = 1
    REQUIRED_MASK = 1
    # Bits 1-2: cardinality.
    SINGLE = 0
    REPEATED = 2
    PACKED_REPEATED = 6
    REPEATED_MASK = 6
    # Bit 3: whether the field holds an embedded definition (intended for a
    # MessageMetaType, which is not defined in this module).
    PRIMITIVE = 0
    EMBEDDED = 8
    EMBEDDED_MASK = 8
class EofWrapper:
    """
    Wraps a readable stream so that reading at end of stream raises EOFError
    instead of returning empty bytes.

    If ``limit`` is given, at most ``limit`` bytes in total are read through
    the wrapper; once it is used up, reads raise EOFError. This bounds
    length-delimited sub-streams (embedded messages, packed fields).
    """

    def __init__(self, fp, limit=None):
        self.__fp = fp
        self.__limit = limit

    def read(self, size=None):
        """
        Reads up to ``size`` bytes; with ``size`` None or negative, reads
        everything (or everything left within the limit).

        A zero-size read returns b'' without touching the stream. Raises
        EOFError when nothing could be read.
        """
        if size == 0:
            # Must be bytes, not str: an empty length-delimited field is read
            # with size 0 and callers such as UnicodeType.load call .decode().
            return b''
        if self.__limit is not None:
            if size is None or size < 0:
                size = self.__limit
            else:
                size = min(size, self.__limit)
            self.__limit -= size
        s = self.__fp.read(size)
        if len(s) == 0:
            raise EOFError()
        return s
def _pack_key(tag, wire_type):
    """
    Combines a field tag and a wire type into one protobuf key: the tag in
    the high bits, the wire type in the low three bits.
    """
    return wire_type | (tag << 3)


def _unpack_key(key):
    """
    Splits a protobuf key into ``(tag, wire_type)``; inverse of _pack_key.
    """
    tag, wire_type = divmod(key, 8)
    return tag, wire_type
# Wire type -> type instance whose load() consumes a value of that wire type.
# MessageType.load uses it to read past fields whose tag is not declared.
_wire_type_to_type_instance = {
    0: Varint,
    1: Fixed64,
    2: Bytes,
    5: Fixed32,
}
class MessageType(Type):
    """
    Represents a message type: a schema mapping each field tag to a field
    name, a field type instance and Flags.

    Calling the instance creates an empty Message; dump()/load() convert
    Message instances to and from the protobuf wire format.
    """

    def __init__(self):
        """
        Creates a new message type with no fields.
        """
        self.__tags_to_types = dict()  # Maps a tag to a type instance.
        self.__tags_to_names = dict()  # Maps a tag to a given field name.
        self.__flags = dict()  # Maps a tag to flags.

    def __hash__(self):
        """
        Hashes the field definitions (tag, type, flags) in iteration order.
        """
        _hash = 17
        for tag, name, field_type, flags in iter(self):
            _hash = hash((_hash, tag, field_type, flags))
        return _hash

    def __iter__(self):
        """
        Iterates over all fields as (tag, name, field_type, flags) tuples.
        """
        for tag, name in self.__tags_to_names.items():
            yield (tag, name, self.__tags_to_types[tag], self.__flags[tag])

    def add_field(self, tag, name, field_type, flags=Flags.SIMPLE):
        """
        Adds a field to the message type and returns ``self`` so calls can
        be chained.

        Raises ValueError if ``tag`` is already used.
        """
        if tag in self.__tags_to_names or tag in self.__tags_to_types:
            raise ValueError('The tag %s is already used.' % tag)
        self.__tags_to_names[tag] = name
        self.__tags_to_types[tag] = field_type
        self.__flags[tag] = flags
        return self  # Allow add_field chaining.

    def remove_field(self, tag):
        """
        Removes a field by its tag. Doesn't raise any exception when the tag is
        missing.
        """
        self.__tags_to_names.pop(tag, None)
        self.__tags_to_types.pop(tag, None)
        # Drop the flags entry too so nothing stale outlives the field.
        self.__flags.pop(tag, None)

    def __call__(self):
        """
        Creates an empty instance of this message type.
        """
        return Message(self)

    def __has_flag(self, tag, flag, mask):
        """
        Checks whether the field with the specified tag has the specified flag
        (``flags & mask == flag``).
        """
        return (self.__flags[tag] & mask) == flag

    def dump(self, fp, value):
        """
        Writes ``value`` (a Message of this type) to ``fp``.

        Raises TypeError if ``value`` belongs to another message type and
        ValueError if a REQUIRED field is missing.
        """
        if self != value.message_type:
            raise TypeError(
                'Attempting to dump an object with type that\'s different '
                'from mine.'
            )
        for tag, field_type in self.__tags_to_types.items():
            name = self.__tags_to_names[tag]
            if name in value:
                if self.__has_flag(tag, Flags.SINGLE, Flags.REPEATED_MASK):
                    # Single value: key followed by the value.
                    UVarint.dump(fp, _pack_key(tag, field_type.WIRE_TYPE))
                    field_type.dump(fp, value[name])
                elif self.__has_flag(
                    tag, Flags.PACKED_REPEATED, Flags.REPEATED_MASK
                ):
                    # Packed: one length-delimited record holding all values.
                    UVarint.dump(fp, _pack_key(tag, Bytes.WIRE_TYPE))
                    internal_fp = BytesIO()
                    for single_value in value[name]:
                        field_type.dump(internal_fp, single_value)
                    Bytes.dump(fp, internal_fp.getvalue())
                elif self.__has_flag(tag, Flags.REPEATED, Flags.REPEATED_MASK):
                    # Unpacked repeated: key + value once per element.
                    key = _pack_key(tag, field_type.WIRE_TYPE)
                    for single_value in value[name]:
                        UVarint.dump(fp, key)
                        field_type.dump(fp, single_value)
            elif self.__has_flag(tag, Flags.REQUIRED, Flags.REQUIRED_MASK):
                raise ValueError(
                    'The field with the tag %s is required but a value is '
                    'missing.' % tag
                )

    def load(self, fp):
        """
        Reads fields from ``fp`` until end of stream and returns a Message.

        Fields with unknown tags are skipped. Raises TypeError when a known
        field arrives with the wrong wire type, and ValueError when a
        REQUIRED non-repeated field is absent (a missing REQUIRED repeated
        field becomes an empty list).
        """
        fp, message = (
            EofWrapper(fp),
            self.__call__(),
        )  # Wrap fp and create a new instance.
        while True:
            try:
                tag, wire_type = _unpack_key(UVarint.load(fp))
                if tag in self.__tags_to_types:
                    field_type = self.__tags_to_types[tag]
                    name = self.__tags_to_names[tag]
                    if not self.__has_flag(
                        tag, Flags.PACKED_REPEATED, Flags.REPEATED_MASK
                    ):
                        if wire_type != field_type.WIRE_TYPE:
                            raise TypeError(
                                'The received value with the tag %s has '
                                'incorrect wiretype: %s instead of %s '
                                'expected.'
                                % (tag, wire_type, field_type.WIRE_TYPE)
                            )
                    elif wire_type != Bytes.WIRE_TYPE:
                        raise TypeError(
                            'Tag %s has wiretype %s while the field is packed '
                            'repeated.'
                            % (tag, wire_type)
                        )
                    if self.__has_flag(tag, Flags.SINGLE, Flags.REPEATED_MASK):
                        # Single value: a later occurrence overwrites.
                        message[name] = field_type.load(fp)
                    elif self.__has_flag(
                        tag, Flags.PACKED_REPEATED, Flags.REPEATED_MASK
                    ):
                        # Packed: values of several records for the same
                        # field are concatenated, not replaced.
                        repeated_value = message.setdefault(name, list())
                        internal_fp = EofWrapper(
                            fp, UVarint.load(fp)
                        )  # Limit with value length.
                        while True:
                            try:
                                repeated_value.append(
                                    field_type.load(internal_fp)
                                )
                            except EOFError:
                                break
                    elif self.__has_flag(
                        tag, Flags.REPEATED, Flags.REPEATED_MASK
                    ):
                        # Unpacked repeated: look this field's own list up
                        # every time, since occurrences of different repeated
                        # fields may be interleaved on the wire.
                        message.setdefault(name, list()).append(
                            field_type.load(fp)
                        )
                else:
                    # Unknown tag: read and discard its value.
                    _wire_type_to_type_instance[wire_type].load(fp)
            except EOFError:
                # Check if all required fields are present.
                for tag, name in self.__tags_to_names.items():
                    has_flag = self.__has_flag(
                        tag, Flags.REQUIRED, Flags.REQUIRED_MASK
                    )
                    if has_flag and (name not in message):
                        if self.__has_flag(
                            tag, Flags.REPEATED, Flags.REPEATED_MASK
                        ):
                            # Empty list (no values was in input stream).
                            # But required field.
                            message[name] = list()
                        else:
                            raise ValueError(
                                'The field with the tag %s (\'%s\') is '
                                'required but a value is missing.'
                                % (tag, name)
                            )
                return message
class Message(dict):
    """
    Represents a message instance: a dict of field name -> value whose
    fields can also be read and written as attributes (``msg.name``).

    The owning MessageType lives in the real attribute ``message_type``
    (in ``__dict__``), so it never shows up among the field items.
    """

    def __init__(self, message_type):
        """
        Initializes a new instance of the specified message type.
        """
        super(Message, self).__init__()
        # Bypass __setattr__, which would otherwise store it as a field.
        self.__dict__['message_type'] = message_type

    def __getattr__(self, name):
        """
        Gets a value of the specified message field. Only reached for names
        that are not real attributes; an unset field raises KeyError.
        """
        return self.__getitem__(name)

    def __setattr__(self, name, value):
        """
        Updates a real attribute when ``name`` already is one (e.g.
        ``message_type``), otherwise sets the message field ``name``.
        """
        if name in self.__dict__:
            # Write __dict__ directly; setattr() here would re-enter this
            # method and recurse forever.
            self.__dict__[name] = value
        else:
            self.__setitem__(name, value)
        return value

    def dumps(self):
        """
        Dumps the message into bytes.
        """
        return self.message_type.dumps(self)

    def dump(self, fp):
        """
        Dumps the message into a write-like object.
        """
        return self.message_type.dump(fp, self)

    def loads(self, s, message_type):
        """
        Loads a message of the specified message type from the string
        (``self`` is not used).
        """
        return message_type.loads(s)

    def load(self, fp, message_type):
        """
        Loads a message of the specified message type from the read-like
        object (``self`` is not used).
        """
        return message_type.load(fp)
# Embedded message. -----------------------------------------------------------
class EmbeddedMessage(Type):
    """
    Field type for a nested message: the nested message is serialized on
    its own and stored length-delimited (wire type 2), like Bytes.
    """
    WIRE_TYPE = 2

    def __init__(self, message_type):
        """
        Wraps ``message_type`` (the nested MessageType) as a field type.
        """
        self.message_type = message_type

    def __call__(self):
        """
        Returns a new, empty message of the wrapped type.
        """
        return self.message_type()

    def dump(self, fp, value):
        payload = self.message_type.dumps(value)
        Bytes.dump(fp, payload)

    def load(self, fp):
        # The UVarint length prefix bounds how much the nested load may read.
        size = UVarint.load(fp)
        return self.message_type.load(EofWrapper(fp, size))
import socket
import ssl
import time
import struct
from machine import Timer
#from apscheduler.schedulers.background import BackgroundScheduler
import ESPPacket
# Client state.
Session = 0
Actors, Channels = {}, {}

# TCP connection to the Mumble server, then wrapped in TLS.
# NOTE(review): no CA/hostname verification is configured for the TLS
# wrapper — confirm whether the target platform supports it.
_SERVER_ADDRESS = ("108.228.59.57", 64738)
sock = socket.socket()
sock.connect(_SERVER_ADDRESS)
ssock = ssl.wrap_socket(sock)

# Timer with id -1 (a software/virtual timer on MicroPython ports — confirm
# for the target board); it is initialised later with the ping callback.
timer = Timer(-1)
def Ping(test):
    """
    Timer callback: sends a Ping (message type 3) carrying the current time
    in whole seconds. ``test`` receives the callback argument (presumably
    the Timer object) and is unused.
    """
    print("Sending Ping")
    message = ESPPacket.Ping()
    message["timestamp"] = int(time.time())
    Encode(3, message.dumps())
def ExchangeVersions():  # Fake it!
    """
    Reads the next framed message, parses it as a Version, and sends that
    same version back as our own (message type 0).
    """
    _, payload = Decode()
    server_version = ESPPacket.Version.loads(payload)
    server_version.message_type = ESPPacket.Version
    Encode(0, server_version.dumps())
def _read_exact(count):
    """
    Reads exactly ``count`` bytes from ``ssock``; raises EOFError if the
    connection ends first.
    """
    # A single TLS read may return fewer bytes than requested (e.g. when a
    # frame spans several TLS records), so keep reading until complete.
    chunks = []
    remaining = count
    while remaining > 0:
        chunk = ssock.read(remaining)
        if not chunk:
            raise EOFError('connection closed while reading a message')
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


def Decode():
    """
    Reads one framed message from the TLS socket.

    Frame layout: 2-byte big-endian message type, 4-byte big-endian payload
    length, then the payload. Returns ``[type, payload]``.
    """
    msg_type, length = struct.unpack(">HI", _read_exact(6))
    return [msg_type, _read_exact(length)]
def Encode(type, message):
    """
    Sends one framed message: 2-byte big-endian ``type``, 4-byte big-endian
    payload length, then ``message`` (bytes).

    The frame is written with a single write so that a frame sent from the
    timer callback cannot be spliced between the header and payload of
    another frame, and so each frame is not split into three TLS records.
    """
    header = struct.pack(">HI", type, len(message))
    ssock.write(header + message)
# Start the 30 s ping timer, do the version exchange, log in, then print
# every message the server sends.
timer.init(period=30000, callback=Ping)
ExchangeVersions()
auth = ESPPacket.Authenticate()
auth.username = "testing"
print(auth)
Encode(2, auth.dumps())
while True:
    # `msg_type` rather than `type`, to avoid shadowing the builtin.
    msg_type, message = Decode()
    if msg_type == 3:
        ping = ESPPacket.Ping.loads(message)
        # The echo may lack a timestamp; don't crash the loop over it.
        if "timestamp" in ping:
            print("Ping: "+str(int(time.time())-ping.timestamp))
        print(ping)
    elif msg_type == 5:
        # ServerSync carries the session id assigned to this client.
        server_sync = ESPPacket.ServerSync.loads(message)
        print(server_sync)
        Session = server_sync.get("session", Session)
    elif msg_type == 9:
        # UserState describes any user, not necessarily us, so it must not
        # overwrite our own Session.
        userState = ESPPacket.UserState.loads(message)
        print(userState)
    elif msg_type in ESPPacket.Calls:
        print(ESPPacket.Calls[msg_type].loads(message))
    else:
        print(msg_type)
#!/usr/bin/env python3
# coding: utf-8
"""
Legacy interface.
"""
import struct
from io import BytesIO
# Types. ----------------------------------------------------------------------
def b(string):
    """
    Converts a text string whose characters are all in U+0000..U+00FF to
    bytes, mapping each character to the byte with the same value.

    Uses the canonical codec name 'latin-1' (the previous 'latin=1' only
    worked because CPython normalises codec names).
    """
    return bytes(string, encoding='latin-1')
class Type(object):
    """
    Abstract base of every field type.

    Subclasses implement dump()/load() on file-like objects; dumps() and
    loads() adapt those to in-memory bytes.
    """

    def dump(self, fp, value):
        """
        Writes ``value`` to the writable file-like ``fp``. Abstract.
        """
        raise TypeError('Don\'t call this directly.')

    def load(self, fp):
        """
        Reads one value from the readable file-like ``fp`` and returns it.
        Abstract.
        """
        raise TypeError('Don\'t call this directly.')

    def dumps(self, value):
        """
        Serializes ``value`` and returns the resulting bytes.
        """
        sink = BytesIO()
        self.dump(sink, value)
        return sink.getvalue()

    def loads(self, s):
        """
        Deserializes one value from the bytes ``s``.
        """
        source = BytesIO(s)
        return self.load(source)

    def __hash__(self):
        """
        Hashes by class name, so all instances of one field class hash alike.
        """
        return hash(type(self).__name__)
class UVarintType(Type):
    """
    Represents an unsigned Varint type (base-128 varint, wire type 0).
    """
    WIRE_TYPE = 0

    def dump(self, fp, value):
        """
        Writes ``value`` 7 bits per byte, least significant group first;
        every byte except the last has its high bit (0x80) set.

        Raises ValueError for a negative ``value``: it has no unsigned varint
        form, and right-shifting a negative int never reaches 0, so the loop
        below would never terminate.
        """
        if value < 0:
            raise ValueError(
                'Cannot encode negative value %r as UVarint.' % (value,)
            )
        shifted_value = True
        while shifted_value:
            shifted_value = value >> 7
            fp.write(bytearray(((value & 0x7F) | (0x80 if shifted_value != 0 else 0x00),)))
            value = shifted_value

    def load(self, fp):
        """
        Reads one varint from ``fp`` and returns it as a non-negative int.

        End-of-stream behaviour depends on ``fp``: EofWrapper raises
        EOFError, while a plain BytesIO makes ord() fail on b''.
        """
        value, shift, quantum = 0, 0, 0x80
        while (quantum & 0x80) == 0x80:
            quantum = ord(fp.read(1))
            value, shift = value + ((quantum & 0x7F) << shift), shift + 7
        return value
class VarintType(UVarintType):
    """
    Represents a signed Varint type using ZigZag encoding
    (0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...) on top of UVarintType.
    """

    def dump(self, fp, value):
        zigzag = (abs(value) << 1) - (1 if value < 0 else 0)
        UVarintType.dump(self, fp, zigzag)

    def load(self, fp):
        raw = UVarintType.load(self, fp)
        magnitude = (raw + 1) >> 1
        # Odd raw values encode negatives, even ones non-negatives.
        return -magnitude if raw & 1 else magnitude
class BoolType(UVarintType):
    """
    Represents a boolean type. True is written as the single byte 0x01 and
    False as 0x00 (both one-byte UVarints); any non-zero UVarint loads as
    True.
    """

    def dump(self, fp, value):
        fp.write(b'\x01' if value else b'\x00')

    def load(self, fp):
        return bool(UVarintType.load(self, fp))
class BytesType(Type):
    """
    Represents a raw bytes type: a UVarint length prefix followed by the
    payload (wire type 2).
    """
    WIRE_TYPE = 2

    def dump(self, fp, value):
        size = len(value)
        UVarint.dump(fp, size)
        fp.write(value)

    def load(self, fp):
        size = UVarint.load(fp)
        return fp.read(size)
class UnicodeType(BytesType):
    """
    Represents a text type: the UTF-8 encoding of the string, stored like
    BytesType (UVarint length prefix + payload).
    """

    def dump(self, fp, value):
        encoded = value.encode('utf-8')
        return BytesType.dump(self, fp, encoded)

    def load(self, fp):
        raw = BytesType.load(self, fp)
        return raw.decode('utf-8')
class FixedLengthType(Type):
    """
    Base class for fixed-width values; not meant to be used directly.

    Subclasses provide length(). dump() writes already-packed bytes
    verbatim and load() requests length() bytes from the stream.
    """

    def length(self):
        raise NotImplementedError()

    def dump(self, fp, value):
        fp.write(value)

    def load(self, fp):
        size = self.length()
        return fp.read(size)
class Fixed64Type(FixedLengthType):
    """
    Represents a raw 8-byte value (wire type 1).
    """
    WIRE_TYPE = 1

    def length(self):
        return 64 // 8


class Fixed32Type(FixedLengthType):
    """
    Represents a raw 4-byte value (wire type 5).
    """
    WIRE_TYPE = 5

    def length(self):
        return 32 // 8
class Fixed64SubType(Fixed64Type):
    """
    Base for 64-bit values converted with ``struct``; subclasses set the
    class attribute ``format``.

    Protobuf stores fixed64, sfixed64 and double values little-endian, so
    every format below starts with '<' (explicit byte order, standard size).
    Big-endian formats and native-order formats do not match that wire format.
    """

    def dump(self, fp, value):
        return Fixed64Type.dump(self, fp, struct.pack(self.format, value))

    def load(self, fp):
        return struct.unpack(self.format, Fixed64Type.load(self, fp))[0]


class UInt64Type(Fixed64SubType):
    """
    Represents an unsigned int64 type (protobuf fixed64).
    """
    format = '<Q'


class Int64Type(Fixed64SubType):
    """
    Represents a signed int64 type (protobuf sfixed64).
    """
    format = '<q'


class Float64Type(Fixed64SubType):
    """
    Represents a double precision floating point type (protobuf double).
    """
    format = '<d'
class Fixed32SubType(Fixed32Type):
    """
    Base for 32-bit values converted with ``struct``; subclasses set the
    class attribute ``format``.

    Protobuf stores fixed32, sfixed32 and float values little-endian, so
    every format below starts with '<' (explicit byte order, standard size).
    Big-endian formats and native-order formats do not match that wire format.
    """

    def dump(self, fp, value):
        return Fixed32Type.dump(self, fp, struct.pack(self.format, value))

    def load(self, fp):
        return struct.unpack(self.format, Fixed32Type.load(self, fp))[0]


class UInt32Type(Fixed32SubType):
    """
    Represents an unsigned int32 type (protobuf fixed32).
    """
    format = '<I'


class Int32Type(Fixed32SubType):
    """
    Represents a signed int32 type (protobuf sfixed32).
    """
    format = '<i'


class Float32Type(Fixed32SubType):
    """
    Represents a single precision floating point type (protobuf float).
    """
    format = '<f'
# Type instances. --------------------------------------------------------------
# Shared field-type objects to pass to MessageType.add_field when defining a
# message type.

# Varint wire type (0).
UVarint, Varint, Bool = UVarintType(), VarintType(), BoolType()

# 64-bit wire type (1).
Fixed64, UInt64, Int64, Float64 = (
    Fixed64Type(), UInt64Type(), Int64Type(), Float64Type(),
)

# 32-bit wire type (5).
Fixed32, UInt32, Int32, Float32 = (
    Fixed32Type(), UInt32Type(), Int32Type(), Float32Type(),
)

# Length-delimited wire type (2).
Bytes, Unicode = BytesType(), UnicodeType()
# Messages. -------------------------------------------------------------------
class Flags(object):
    """
    Bit flags describing a field. Combine them with ``|``; test a flag by
    comparing ``flags & <MASK>`` with the flag value.
    """
    # Plain single-valued field.
    SIMPLE = 0
    # Bit 0: the field must be present.
    REQUIRED = 1
    REQUIRED_MASK = 1
    # Bits 1-2: cardinality.
    SINGLE = 0
    REPEATED = 2
    PACKED_REPEATED = 6
    REPEATED_MASK = 6
    # Bit 3: whether the field holds an embedded definition (intended for a
    # MessageMetaType, which is not defined in this module).
    PRIMITIVE = 0
    EMBEDDED = 8
    EMBEDDED_MASK = 8
class EofWrapper:
    """
    Wraps a readable stream so that reading at end of stream raises EOFError
    instead of returning empty bytes.

    If ``limit`` is given, at most ``limit`` bytes in total are read through
    the wrapper; once it is used up, reads raise EOFError. This bounds
    length-delimited sub-streams (embedded messages, packed fields).
    """

    def __init__(self, fp, limit=None):
        self.__fp = fp
        self.__limit = limit

    def read(self, size=None):
        """
        Reads up to ``size`` bytes; with ``size`` None or negative, reads
        everything (or everything left within the limit).

        A zero-size read returns b'' without touching the stream. Raises
        EOFError when nothing could be read.
        """
        if size == 0:
            # Must be bytes, not str: an empty length-delimited field is read
            # with size 0 and callers such as UnicodeType.load call .decode().
            return b''
        if self.__limit is not None:
            if size is None or size < 0:
                size = self.__limit
            else:
                size = min(size, self.__limit)
            self.__limit -= size
        s = self.__fp.read(size)
        if len(s) == 0:
            raise EOFError()
        return s
def _pack_key(tag, wire_type):
    """
    Combines a field tag and a wire type into one protobuf key: the tag in
    the high bits, the wire type in the low three bits.
    """
    return wire_type | (tag << 3)


def _unpack_key(key):
    """
    Splits a protobuf key into ``(tag, wire_type)``; inverse of _pack_key.
    """
    tag, wire_type = divmod(key, 8)
    return tag, wire_type
# Wire type -> type instance whose load() consumes a value of that wire type.
# MessageType.load uses it to read past fields whose tag is not declared.
_wire_type_to_type_instance = {
    0: Varint,
    1: Fixed64,
    2: Bytes,
    5: Fixed32,
}
class MessageType(Type):
    """
    Represents a message type: a schema mapping each field tag to a field
    name, a field type instance and Flags.

    Calling the instance creates an empty Message; dump()/load() convert
    Message instances to and from the protobuf wire format.
    """

    def __init__(self):
        """
        Creates a new message type with no fields.
        """
        self.__tags_to_types = dict()  # Maps a tag to a type instance.
        self.__tags_to_names = dict()  # Maps a tag to a given field name.
        self.__flags = dict()  # Maps a tag to flags.

    def __hash__(self):
        """
        Hashes the field definitions (tag, type, flags) in iteration order.
        """
        _hash = 17
        for tag, name, field_type, flags in iter(self):
            _hash = hash((_hash, tag, field_type, flags))
        return _hash

    def __iter__(self):
        """
        Iterates over all fields as (tag, name, field_type, flags) tuples.
        """
        for tag, name in self.__tags_to_names.items():
            yield (tag, name, self.__tags_to_types[tag], self.__flags[tag])

    def add_field(self, tag, name, field_type, flags=Flags.SIMPLE):
        """
        Adds a field to the message type and returns ``self`` so calls can
        be chained.

        Raises ValueError if ``tag`` is already used.
        """
        if tag in self.__tags_to_names or tag in self.__tags_to_types:
            raise ValueError('The tag %s is already used.' % tag)
        self.__tags_to_names[tag] = name
        self.__tags_to_types[tag] = field_type
        self.__flags[tag] = flags
        return self  # Allow add_field chaining.

    def remove_field(self, tag):
        """
        Removes a field by its tag. Doesn't raise any exception when the tag is
        missing.
        """
        self.__tags_to_names.pop(tag, None)
        self.__tags_to_types.pop(tag, None)
        # Drop the flags entry too so nothing stale outlives the field.
        self.__flags.pop(tag, None)

    def __call__(self):
        """
        Creates an empty instance of this message type.
        """
        return Message(self)

    def __has_flag(self, tag, flag, mask):
        """
        Checks whether the field with the specified tag has the specified flag
        (``flags & mask == flag``).
        """
        return (self.__flags[tag] & mask) == flag

    def dump(self, fp, value):
        """
        Writes ``value`` (a Message of this type) to ``fp``.

        Raises TypeError if ``value`` belongs to another message type and
        ValueError if a REQUIRED field is missing.
        """
        if self != value.message_type:
            raise TypeError(
                'Attempting to dump an object with type that\'s different '
                'from mine.'
            )
        for tag, field_type in self.__tags_to_types.items():
            name = self.__tags_to_names[tag]
            if name in value:
                if self.__has_flag(tag, Flags.SINGLE, Flags.REPEATED_MASK):
                    # Single value: key followed by the value.
                    UVarint.dump(fp, _pack_key(tag, field_type.WIRE_TYPE))
                    field_type.dump(fp, value[name])
                elif self.__has_flag(
                    tag, Flags.PACKED_REPEATED, Flags.REPEATED_MASK
                ):
                    # Packed: one length-delimited record holding all values.
                    UVarint.dump(fp, _pack_key(tag, Bytes.WIRE_TYPE))
                    internal_fp = BytesIO()
                    for single_value in value[name]:
                        field_type.dump(internal_fp, single_value)
                    Bytes.dump(fp, internal_fp.getvalue())
                elif self.__has_flag(tag, Flags.REPEATED, Flags.REPEATED_MASK):
                    # Unpacked repeated: key + value once per element.
                    key = _pack_key(tag, field_type.WIRE_TYPE)
                    for single_value in value[name]:
                        UVarint.dump(fp, key)
                        field_type.dump(fp, single_value)
            elif self.__has_flag(tag, Flags.REQUIRED, Flags.REQUIRED_MASK):
                raise ValueError(
                    'The field with the tag %s is required but a value is '
                    'missing.' % tag
                )

    def load(self, fp):
        """
        Reads fields from ``fp`` until end of stream and returns a Message.

        Fields with unknown tags are skipped. Raises TypeError when a known
        field arrives with the wrong wire type, and ValueError when a
        REQUIRED non-repeated field is absent (a missing REQUIRED repeated
        field becomes an empty list).
        """
        fp, message = (
            EofWrapper(fp),
            self.__call__(),
        )  # Wrap fp and create a new instance.
        while True:
            try:
                tag, wire_type = _unpack_key(UVarint.load(fp))
                if tag in self.__tags_to_types:
                    field_type = self.__tags_to_types[tag]
                    name = self.__tags_to_names[tag]
                    if not self.__has_flag(
                        tag, Flags.PACKED_REPEATED, Flags.REPEATED_MASK
                    ):
                        if wire_type != field_type.WIRE_TYPE:
                            raise TypeError(
                                'The received value with the tag %s has '
                                'incorrect wiretype: %s instead of %s '
                                'expected.'
                                % (tag, wire_type, field_type.WIRE_TYPE)
                            )
                    elif wire_type != Bytes.WIRE_TYPE:
                        raise TypeError(
                            'Tag %s has wiretype %s while the field is packed '
                            'repeated.'
                            % (tag, wire_type)
                        )
                    if self.__has_flag(tag, Flags.SINGLE, Flags.REPEATED_MASK):
                        # Single value: a later occurrence overwrites.
                        message[name] = field_type.load(fp)
                    elif self.__has_flag(
                        tag, Flags.PACKED_REPEATED, Flags.REPEATED_MASK
                    ):
                        # Packed: values of several records for the same
                        # field are concatenated, not replaced.
                        repeated_value = message.setdefault(name, list())
                        internal_fp = EofWrapper(
                            fp, UVarint.load(fp)
                        )  # Limit with value length.
                        while True:
                            try:
                                repeated_value.append(
                                    field_type.load(internal_fp)
                                )
                            except EOFError:
                                break
                    elif self.__has_flag(
                        tag, Flags.REPEATED, Flags.REPEATED_MASK
                    ):
                        # Unpacked repeated: look this field's own list up
                        # every time, since occurrences of different repeated
                        # fields may be interleaved on the wire.
                        message.setdefault(name, list()).append(
                            field_type.load(fp)
                        )
                else:
                    # Unknown tag: read and discard its value.
                    _wire_type_to_type_instance[wire_type].load(fp)
            except EOFError:
                # Check if all required fields are present.
                for tag, name in self.__tags_to_names.items():
                    has_flag = self.__has_flag(
                        tag, Flags.REQUIRED, Flags.REQUIRED_MASK
                    )
                    if has_flag and (name not in message):
                        if self.__has_flag(
                            tag, Flags.REPEATED, Flags.REPEATED_MASK
                        ):
                            # Empty list (no values was in input stream).
                            # But required field.
                            message[name] = list()
                        else:
                            raise ValueError(
                                'The field with the tag %s (\'%s\') is '
                                'required but a value is missing.'
                                % (tag, name)
                            )
                return message
class Message(dict):
    """
    One message instance: a dict of field name -> value whose fields are
    also reachable as attributes (``msg.name``). The owning MessageType is
    kept in the real attribute ``message_type``, outside the dict items.
    """

    def __init__(self, message_type):
        super(Message, self).__init__()
        # Store directly in __dict__ so __setattr__ does not make it a field.
        self.__dict__['message_type'] = message_type

    def __getattr__(self, name):
        """
        Field lookup for names that are not real attributes; an unset field
        raises KeyError.
        """
        return self[name]

    def __setattr__(self, name, value):
        """
        Updates a real attribute when ``name`` is one, otherwise assigns the
        message field ``name``.
        """
        if name in self.__dict__:
            self.__dict__[name] = value
        else:
            self[name] = value
        return value

    def dumps(self):
        """
        Serializes this message with its own message type; returns bytes.
        """
        return self.message_type.dumps(self)

    def dump(self, fp):
        """
        Serializes this message into the writable file-like ``fp``.
        """
        return self.message_type.dump(fp, self)

    def loads(self, s, message_type):
        """
        Parses ``s`` as a ``message_type`` message (``self`` is not used).
        """
        return message_type.loads(s)

    def load(self, fp, message_type):
        """
        Parses a ``message_type`` message from ``fp`` (``self`` is not used).
        """
        return message_type.load(fp)
# Embedded message. -----------------------------------------------------------
class EmbeddedMessage(Type):
    """
    Field type for a nested message: the nested message is serialized on
    its own and stored length-delimited (wire type 2), like Bytes.
    """
    WIRE_TYPE = 2

    def __init__(self, message_type):
        """
        Wraps ``message_type`` (the nested MessageType) as a field type.
        """
        self.message_type = message_type

    def __call__(self):
        """
        Returns a new, empty message of the wrapped type.
        """
        return self.message_type()

    def dump(self, fp, value):
        payload = self.message_type.dumps(value)
        Bytes.dump(fp, payload)

    def load(self, fp):
        # The UVarint length prefix bounds how much the nested load may read.
        size = UVarint.load(fp)
        return self.message_type.load(EofWrapper(fp, size))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment