Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save glallen01/b209735704a10278cd939fdb217e2eee to your computer and use it in GitHub Desktop.
Save glallen01/b209735704a10278cd939fdb217e2eee to your computer and use it in GitHub Desktop.
Protobuf schema to Spark schema conversion
import google.protobuf.descriptor
from pyspark.sql.types import (
StructType,
StructField,
StringType,
IntegerType,
LongType,
DoubleType,
FloatType,
ArrayType,
MapType,
BooleanType,
)
class ProtobufToSparkSchemaConvertor:
    """
    ProtobufToSparkSchemaConvertor converts a Protobuf message schema to a Spark schema.

    It walks the protobuf message descriptor recursively and creates the
    corresponding Spark type objects. Every generated ``StructField``, array
    element and map value is nullable. Mappings:

    * scalar fields -> the type listed in ``_primitive_types_map``;
    * ``google.protobuf.Timestamp`` -> ``StringType``;
    * ``google.protobuf`` wrapper messages (``Int32Value``, ...) -> the wrapped
      scalar type listed in ``_wrapper_type_names_map``;
    * ``map<K, V>`` -> ``MapType`` (integer keys become ``StringType``);
    * ``repeated T`` (non-map) -> ``ArrayType`` of the Spark type of ``T``;
    * any other message or proto2 group -> nested ``StructType``.

    Recursive (self-referential) messages have no finite Spark schema and
    raise ``ValueError``.
    """

    # Short alias so the class-level tables below stay readable.
    _FieldDescriptor = google.protobuf.descriptor.FieldDescriptor

    # Scalar protobuf field type -> zero-argument factory of the Spark type.
    # List from: https://googleapis.dev/python/protobuf/latest/google/protobuf/descriptor.html#google.protobuf.descriptor.FieldDescriptor.TYPE_BOOL
    # NOTE(review): UINT32 -> IntegerType and UINT64 -> LongType cannot represent
    # the upper half of the unsigned ranges; kept as-is for compatibility.
    _primitive_types_map = {
        _FieldDescriptor.TYPE_BOOL: BooleanType,
        _FieldDescriptor.TYPE_BYTES: StringType,
        _FieldDescriptor.TYPE_DOUBLE: DoubleType,
        _FieldDescriptor.TYPE_ENUM: LongType,
        _FieldDescriptor.TYPE_FIXED32: IntegerType,
        _FieldDescriptor.TYPE_FIXED64: LongType,
        _FieldDescriptor.TYPE_FLOAT: FloatType,
        _FieldDescriptor.TYPE_INT32: IntegerType,
        _FieldDescriptor.TYPE_INT64: LongType,
        _FieldDescriptor.TYPE_SFIXED32: IntegerType,
        _FieldDescriptor.TYPE_SFIXED64: LongType,
        _FieldDescriptor.TYPE_SINT32: IntegerType,
        _FieldDescriptor.TYPE_SINT64: LongType,
        _FieldDescriptor.TYPE_STRING: StringType,
        _FieldDescriptor.TYPE_UINT32: IntegerType,
        _FieldDescriptor.TYPE_UINT64: LongType,
    }

    # Well-known wrapper messages keyed by short name. A message only counts as
    # a wrapper when its full name is "google.protobuf.<short name>".
    # List from: https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/wrappers.proto
    _wrapper_type_names_map = {
        "DoubleValue": DoubleType,
        "FloatValue": FloatType,
        "Int64Value": LongType,
        "UInt64Value": LongType,
        "Int32Value": IntegerType,
        "UInt32Value": IntegerType,
        "BoolValue": BooleanType,
        "StringValue": StringType,
        "BytesValue": StringType,
    }

    # Package prefix shared by the protobuf well-known types.
    _WELL_KNOWN_TYPES_PREFIX = "google.protobuf."

    # Full name of the well-known Timestamp message:
    # https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/timestamp.proto
    _TIMESTAMP_FULL_NAME = "google.protobuf.Timestamp"

    # Integer scalar types; used to choose the Spark type of map keys.
    _int_types = frozenset(
        {
            _FieldDescriptor.TYPE_FIXED32,
            _FieldDescriptor.TYPE_FIXED64,
            _FieldDescriptor.TYPE_INT32,
            _FieldDescriptor.TYPE_INT64,
            _FieldDescriptor.TYPE_SFIXED32,
            _FieldDescriptor.TYPE_SFIXED64,
            _FieldDescriptor.TYPE_SINT32,
            _FieldDescriptor.TYPE_SINT64,
            _FieldDescriptor.TYPE_UINT32,
            _FieldDescriptor.TYPE_UINT64,
        }
    )

    # Field types whose values are described by a nested message descriptor
    # (proto2 groups carry a ``message_type`` just like message fields).
    _message_field_types = frozenset(
        {_FieldDescriptor.TYPE_MESSAGE, _FieldDescriptor.TYPE_GROUP}
    )

    def get_schema(self, descriptor: google.protobuf.descriptor.Descriptor) -> StructType:
        """Return the Spark ``StructType`` equivalent of the message ``descriptor``.

        Raises:
            ValueError: if a field has a scalar type with no Spark mapping, or if
                the message definition is recursive.
        """
        full_schema = []
        self._walk_protobuf_descriptor(descriptor, full_schema)
        return StructType(full_schema)

    def _is_int_type(self, field_descriptor: google.protobuf.descriptor.FieldDescriptor) -> bool:
        """Return True if the field has any protobuf integer scalar type."""
        return field_descriptor.type in self._int_types

    def _is_primitive_type(
        self, field_descriptor: google.protobuf.descriptor.FieldDescriptor
    ) -> bool:
        """Return True if the field's type has an entry in ``_primitive_types_map``."""
        return field_descriptor.type in self._primitive_types_map

    def _is_map_type(self, field_descriptor: google.protobuf.descriptor.FieldDescriptor) -> bool:
        """Return True if the field is a ``map<K, V>`` field.

        Map fields reference a synthetic entry message whose options have
        ``map_entry`` set.
        """
        message_type = field_descriptor.message_type
        return bool(
            message_type
            and message_type.has_options
            and message_type.GetOptions().map_entry
        )

    def _is_proto_wrapper_type(
        self, field_descriptor: google.protobuf.descriptor.FieldDescriptor
    ) -> bool:
        """Return True if the field's message is a ``google.protobuf`` wrapper type."""
        message_type = field_descriptor.message_type
        return (
            message_type is not None
            and message_type.name in self._wrapper_type_names_map
            # Match on the full name so user messages that merely share a short
            # name (e.g. ``mypkg.Int32Value``) are still converted as structs.
            and message_type.full_name == self._WELL_KNOWN_TYPES_PREFIX + message_type.name
        )

    def _is_timestamp_type(
        self, field_descriptor: google.protobuf.descriptor.FieldDescriptor
    ) -> bool:
        """Return True if the field's message is ``google.protobuf.Timestamp``."""
        message_type = field_descriptor.message_type
        return message_type is not None and message_type.full_name == self._TIMESTAMP_FULL_NAME

    def _scalar_to_spark_type(self, field: google.protobuf.descriptor.FieldDescriptor):
        """Return a Spark type instance for a scalar field (repetition ignored).

        Raises:
            ValueError: if the field's type has no entry in ``_primitive_types_map``.
        """
        type_factory = self._primitive_types_map.get(field.type)
        if type_factory is None:
            raise ValueError(f"Missing primitive type for: {field.name} | {field.type}")
        return type_factory()

    def _message_field_to_spark_type(
        self, field: google.protobuf.descriptor.FieldDescriptor, _path: tuple = ()
    ):
        """Return the Spark type for a message/group-typed field (repetition ignored).

        Shared by regular fields and map values so both convert well-known types
        the same way.

        Args:
            field: field whose ``message_type`` is set.
            _path: full names of the messages currently being expanded.
        """
        if self._is_timestamp_type(field):
            return StringType()
        if self._is_proto_wrapper_type(field):
            return self._wrapper_type_names_map[field.message_type.name]()
        if self._is_map_type(field):
            return self._handle_map(field, _path)
        schema_list = []
        self._walk_protobuf_descriptor(field.message_type, schema_list, _path)
        return StructType(schema_list)

    def _handle_map(self, field: google.protobuf.descriptor.FieldDescriptor, _path: tuple = ()):
        """Return the ``MapType`` for a ``map<K, V>`` field.

        The synthetic map-entry message holds the key and value descriptors in
        fields named ``key`` and ``value``.

        Args:
            field: a field for which ``_is_map_type`` is True.
            _path: full names of the messages currently being expanded.
        """
        entry_fields = field.message_type.fields_by_name
        key = entry_fields["key"]
        value = entry_fields["value"]
        if self._is_int_type(key):
            # Protobuf serializes integers as strings in maps:
            # https://github.com/protocolbuffers/protobuf/issues/7769
            key_type = StringType()
        else:
            key_type = self._scalar_to_spark_type(key)
        if self._is_primitive_type(value):
            value_type = self._scalar_to_spark_type(value)
        else:
            value_type = self._message_field_to_spark_type(value, _path)
        return MapType(key_type, value_type, True)

    def _walk_protobuf_descriptor(
        self,
        descriptor: google.protobuf.descriptor.Descriptor,
        schema: list,
        _path: tuple = (),
    ) -> None:
        """Append one nullable ``StructField`` per field of ``descriptor`` to ``schema``.

        Args:
            descriptor: message descriptor whose fields are converted.
            schema: list that receives the generated ``StructField`` objects
                (mutated in place).
            _path: full names of the messages currently being expanded; used to
                detect recursive definitions, which would otherwise recurse forever.

        Raises:
            ValueError: on a scalar type with no Spark mapping, or when
                ``descriptor`` is already being expanded (recursive message).
        """
        if descriptor.full_name in _path:
            cycle = " -> ".join(_path + (descriptor.full_name,))
            raise ValueError(
                f"Recursive message cannot be converted to a Spark schema: {cycle}"
            )
        path = _path + (descriptor.full_name,)
        for field in descriptor.fields:
            if field.type in self._message_field_types:
                data_type = self._message_field_to_spark_type(field, path)
            else:
                data_type = self._scalar_to_spark_type(field)
            # Map fields are also labelled repeated, but they are already a MapType
            # and must not be wrapped in an array.
            if (
                field.label == self._FieldDescriptor.LABEL_REPEATED
                and not self._is_map_type(field)
            ):
                data_type = ArrayType(data_type, True)
            schema.append(StructField(field.name, data_type, True))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment