Created June 13, 2023 04:59
Save tomzx/d8cf8d4a34a2a3d989e0d2f262d4f8f4 to your computer and use it in GitHub Desktop.
Create a TensorFlow Serving configuration in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Minimal schema for a TensorFlow Serving model server configuration,
// compiled to `model_config_pb2` for use from Python.
// NOTE(review): appears to be a trimmed subset of TF Serving's
// model_server_config.proto; field numbers presumably mirror upstream so
// text/binary configs stay compatible — confirm before relying on that.
syntax = "proto3";

// Top-level server configuration: holds the list of models to serve.
// NOTE(review): upstream places this field inside a `oneof config`; this
// subset declares it as a plain field — confirm that is acceptable.
message ModelServerConfig {
  ModelConfigList model_config_list = 1;
}
// Collection of per-model configurations.
message ModelConfigList {
  // One entry per served model.
  repeated ModelConfig config = 1;
}
// Configuration for a single served model.
message ModelConfig {
  // Name of the model.
  string name = 1;
  // Location the model is loaded from (presumably a directory containing
  // numbered version subdirectories, as in TF Serving — confirm).
  string base_path = 2;
  // Which versions of the model to serve.
  // Field numbers 3-6 are not declared in this subset; 7 presumably matches
  // upstream TF Serving's ModelConfig — confirm.
  ServableVersionPolicy model_version_policy = 7;
}
// Policy selecting which versions of a servable are loaded.
message ServableVersionPolicy {
  // Explicit list of version numbers to serve.
  message Specific {
    repeated int64 versions = 1;
  }
  // Only the `specific` policy is modeled here. Field number 102 presumably
  // mirrors upstream TF Serving, where it is one option of a oneof — confirm.
  Specific specific = 102;
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path

from google.protobuf import json_format, text_format

import model_config_pb2
def create_model_server_config(config):
    """Build a ``ModelServerConfig`` message from a list of model dicts.

    Args:
        config: Iterable of dicts shaped like the ``ModelConfig`` entries
            produced by ``json_format.MessageToDict(...,
            preserving_proto_field_name=True)``. Each entry has ``"name"`` and
            ``"base_path"``, and optionally
            ``"model_version_policy": {"specific": {"versions": [...]}}``.

    Returns:
        A populated ``model_config_pb2.ModelServerConfig``.

    Raises:
        KeyError: if an entry lacks ``"name"`` or ``"base_path"``.
        ValueError: if a version value cannot be converted with ``int()``.
    """
    model_server_config = model_config_pb2.ModelServerConfig()
    for model_variant in config:
        model_config = model_server_config.model_config_list.config.add()
        model_config.name = model_variant["name"]
        model_config.base_path = model_variant["base_path"]

        # MessageToDict omits unset/empty fields, so a model without a version
        # policy (or with an empty `specific` block) simply has no key here.
        policy = model_variant.get("model_version_policy", {})
        specific_dict = policy.get("specific")
        if specific_dict is not None:
            specific = model_config.model_version_policy.specific
            # Keep `specific` present on the rebuilt message even when it
            # lists no versions, so the round trip preserves the policy.
            specific.SetInParent()
            # MessageToDict renders int64 values as strings, hence int().
            versions = [int(v) for v in specific_dict.get("versions", [])]
            specific.versions.extend(versions)
    return model_server_config
def parse_proto_text(proto_text):
    """Parse a text-format ``ModelConfigList`` message.

    Args:
        proto_text: Protobuf text-format representation of a
            ``ModelConfigList``.

    Returns:
        The parsed ``model_config_pb2.ModelConfigList``.

    Raises:
        google.protobuf.text_format.ParseError: if the text is malformed.
    """
    # ParseFromString decodes the binary wire format and cannot read
    # human-readable text; text_format.Parse handles the text syntax.
    return text_format.Parse(proto_text, model_config_pb2.ModelConfigList())
if __name__ == '__main__':
    # Round-trip a text-format server config: parse it, flatten it to plain
    # dicts, then rebuild the message from those dicts and print it.
    # NOTE(review): "model.fill.config" may be a typo (e.g. "model.full.config")
    # — confirm the intended file name.
    config_text = Path("model.fill.config").read_text(encoding="utf-8")
    message = text_format.Parse(config_text, model_config_pb2.ModelServerConfig())
    config = json_format.MessageToDict(message, preserving_proto_field_name=True)
    # MessageToDict omits empty fields, so a config with no models has neither
    # a "model_config_list" nor a "config" key.
    config = config.get("model_config_list", {}).get("config", [])
    model_server_config = create_model_server_config(config)
    print(model_server_config)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.