Forked from ragoragino/protobuf_spark_schema_convertor.py
Created June 12, 2024 22:01
-
-
Save glallen01/b209735704a10278cd939fdb217e2eee to your computer and use it in GitHub Desktop.
Protobuf schema to Spark schema conversion
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import google.protobuf.descriptor | |
from pyspark.sql.types import ( | |
StructType, | |
StructField, | |
StringType, | |
IntegerType, | |
LongType, | |
DoubleType, | |
FloatType, | |
ArrayType, | |
MapType, | |
BooleanType, | |
) | |
class ProtobufToSparkSchemaConvertor:
    """
    ProtobufToSparkSchemaConvertor converts a Protobuf message schema to a Spark schema.

    It walks the protobuf descriptor recursively and creates the corresponding
    Spark type objects. Every generated Spark field is nullable.

    Mapping rules:
      * scalar fields use ``_primitive_types_map``;
      * ``google.protobuf.Timestamp`` becomes ``StringType``;
      * ``google.protobuf`` wrapper messages (``Int32Value``, ``StringValue``, ...)
        become the Spark type of the value they wrap;
      * map fields become ``MapType``, with integer keys represented as ``StringType``;
      * repeated (non-map) fields become ``ArrayType`` of the element type;
      * any other message becomes a nested ``StructType``.

    The same rules apply to map values and to struct fields. Well-known types
    are matched by their fully-qualified name. User messages that merely share
    a short name (e.g. ``my.pkg.Timestamp``) are walked as ordinary messages.

    NOTE(review): Timestamp-as-string and integer-map-keys-as-strings mirror
    protobuf's JSON serialization, so the schema presumably targets
    JSON-serialized messages -- confirm against callers.
    """

    # Package that owns the well-known types (Timestamp, wrappers).
    _well_known_package = "google.protobuf"

    # List from: https://googleapis.dev/python/protobuf/latest/google/protobuf/descriptor.html#google.protobuf.descriptor.FieldDescriptor.TYPE_BOOL
    # Values are zero-argument factories that return a Spark type instance.
    # NOTE(review): IntegerType/LongType are signed, so uint32 values above
    # 2**31 - 1 and uint64 values above 2**63 - 1 do not fit them.
    _primitive_types_map = {
        google.protobuf.descriptor.FieldDescriptor.TYPE_BOOL: BooleanType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_BYTES: StringType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_DOUBLE: DoubleType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_ENUM: LongType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_FIXED32: IntegerType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_FIXED64: LongType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_FLOAT: FloatType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_INT32: IntegerType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_INT64: LongType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_SFIXED32: IntegerType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_SFIXED64: LongType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_SINT32: IntegerType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_SINT64: LongType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_STRING: StringType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_UINT32: IntegerType,
        google.protobuf.descriptor.FieldDescriptor.TYPE_UINT64: LongType,
    }

    # List from: https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/wrappers.proto
    # Keyed by short message name; callers must also verify the package
    # (see _is_proto_wrapper_type).
    _wrapper_type_names_map = {
        "DoubleValue": DoubleType,
        "FloatValue": FloatType,
        "Int64Value": LongType,
        "UInt64Value": LongType,
        "Int32Value": IntegerType,
        "UInt32Value": IntegerType,
        "BoolValue": BooleanType,
        "StringValue": StringType,
        "BytesValue": StringType,
    }

    # Every protobuf integer field type (enum and bool are deliberately excluded).
    _int_types = frozenset(
        {
            google.protobuf.descriptor.FieldDescriptor.TYPE_FIXED32,
            google.protobuf.descriptor.FieldDescriptor.TYPE_FIXED64,
            google.protobuf.descriptor.FieldDescriptor.TYPE_INT32,
            google.protobuf.descriptor.FieldDescriptor.TYPE_INT64,
            google.protobuf.descriptor.FieldDescriptor.TYPE_SFIXED32,
            google.protobuf.descriptor.FieldDescriptor.TYPE_SFIXED64,
            google.protobuf.descriptor.FieldDescriptor.TYPE_SINT32,
            google.protobuf.descriptor.FieldDescriptor.TYPE_SINT64,
            google.protobuf.descriptor.FieldDescriptor.TYPE_UINT32,
            google.protobuf.descriptor.FieldDescriptor.TYPE_UINT64,
        }
    )

    def get_schema(self, descriptor: google.protobuf.descriptor.Descriptor):
        """Return the Spark ``StructType`` for the message described by ``descriptor``.

        Raises:
            ValueError: if a field has a type with no Spark mapping, or if the
                message contains itself (directly or transitively).
        """
        full_schema = []
        self._walk_protobuf_descriptor(descriptor, full_schema)
        return StructType(full_schema)

    def _is_int_type(self, field_descriptor: google.protobuf.descriptor.FieldDescriptor) -> bool:
        """Return True if the field holds any protobuf integer type."""
        return field_descriptor.type in self._int_types

    def _is_primitive_type(self, field_descriptor: google.protobuf.descriptor.FieldDescriptor) -> bool:
        """Return True if the field's type has an entry in ``_primitive_types_map``."""
        return field_descriptor.type in self._primitive_types_map

    def _is_well_known_message(self, message_type, name: str) -> bool:
        """Return True if ``message_type`` is exactly ``google.protobuf.<name>``."""
        return (
            message_type is not None
            and message_type.full_name == f"{self._well_known_package}.{name}"
        )

    def _is_timestamp_type(self, field_descriptor: google.protobuf.descriptor.FieldDescriptor) -> bool:
        """Return True if the field is a ``google.protobuf.Timestamp`` message."""
        return self._is_well_known_message(field_descriptor.message_type, "Timestamp")

    def _is_map_type(self, field_descriptor: google.protobuf.descriptor.FieldDescriptor) -> bool:
        """Return True if the field's message type is a synthetic map-entry message."""
        message_type = field_descriptor.message_type
        return bool(
            message_type is not None
            and message_type.has_options
            and message_type.GetOptions().map_entry
        )

    def _is_proto_wrapper_type(self, field_descriptor: google.protobuf.descriptor.FieldDescriptor) -> bool:
        """Return True if the field is one of the ``google.protobuf`` wrapper messages."""
        message_type = field_descriptor.message_type
        return (
            message_type is not None
            and message_type.name in self._wrapper_type_names_map
            # Require the google.protobuf package so a user message that is
            # merely named e.g. "DoubleValue" is not collapsed to a scalar.
            and self._is_well_known_message(message_type, message_type.name)
        )

    def _handle_map(self, field: google.protobuf.descriptor.FieldDescriptor, ancestors=()):
        """Return the ``MapType`` for a map field.

        Maps are backed by an entry message holding "key" and "value" fields.

        Args:
            field: the map field.
            ancestors: fully-qualified names of the messages being expanded
                above this field; forwarded for recursion detection.
        """
        entry = field.message_type
        key = entry.fields_by_name["key"]
        value = entry.fields_by_name["value"]
        if self._is_int_type(key):
            # Protobuf serializes integers as strings in maps:
            # https://github.com/protocolbuffers/protobuf/issues/7769
            key_type = StringType()
        else:
            key_type = self._primitive_types_map[key.type]()
        # Map values follow exactly the same rules as struct fields
        # (Timestamp, wrappers, nested messages, scalars).
        value_type = self._element_spark_type(value, ancestors)
        return MapType(key_type, value_type, True)

    def _element_spark_type(self, field: google.protobuf.descriptor.FieldDescriptor, ancestors):
        """Return the Spark type of a single value of ``field``, ignoring repetition.

        Raises:
            ValueError: if the field's type has no Spark mapping, or if a nested
                message is recursive.
        """
        if field.type == google.protobuf.descriptor.FieldDescriptor.TYPE_MESSAGE:
            # Handle timestamp: https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/timestamp.proto
            if self._is_timestamp_type(field):
                return StringType()
            # Handle wrapper types: https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/wrappers.proto
            if self._is_proto_wrapper_type(field):
                return self._wrapper_type_names_map[field.message_type.name]()
            if self._is_map_type(field):
                return self._handle_map(field, ancestors)
            # Any other message becomes a nested struct.
            nested_schema = []
            self._walk_protobuf_descriptor(field.message_type, nested_schema, ancestors)
            return StructType(nested_schema)

        # .get() (not indexing) so unsupported types reach the explicit error
        # below instead of surfacing as a bare KeyError.
        type_factory = self._primitive_types_map.get(field.type)
        if type_factory is None:
            raise ValueError(f"Missing primitive type for: {field.name} | {field.type}")
        return type_factory()

    def _walk_protobuf_descriptor(
        self,
        descriptor: google.protobuf.descriptor.Descriptor,
        schema,
        ancestors=(),
    ):
        """Append one nullable ``StructField`` per field of ``descriptor`` to ``schema``.

        ``schema`` is mutated in place.

        Args:
            descriptor: message whose fields are converted.
            schema: list that receives the generated ``StructField`` objects.
            ancestors: fully-qualified names of the messages currently being
                expanded above ``descriptor``.

        Raises:
            ValueError: if ``descriptor`` contains itself (directly or
                transitively), since a Spark schema cannot be recursive, or if a
                field's type has no Spark mapping.
        """
        if descriptor.full_name in ancestors:
            chain = " -> ".join(ancestors + (descriptor.full_name,))
            raise ValueError(
                f"Recursive message type cannot be represented as a Spark schema: {chain}"
            )
        path = ancestors + (descriptor.full_name,)

        for field in descriptor.fields:
            data_type = self._element_spark_type(field, path)
            # Map fields are also marked repeated, but they were already turned
            # into a MapType above and must not be wrapped in an array.
            if (
                field.label == google.protobuf.descriptor.FieldDescriptor.LABEL_REPEATED
                and not self._is_map_type(field)
            ):
                data_type = ArrayType(data_type, True)
            schema.append(StructField(field.name, data_type, True))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.