Skip to content

Instantly share code, notes, and snippets.

@fuyi
Created March 15, 2022 07:56
Show Gist options
  • Save fuyi/fcae63ad73c342f706566abbb788a9ad to your computer and use it in GitHub Desktop.
Save fuyi/fcae63ad73c342f706566abbb788a9ad to your computer and use it in GitHub Desktop.
write_feature
@beam.typehints.with_input_types(Tuple[User, List[EventAction]])
class WriteFeatureRPC(beam.DoFn):
    """Buffer feature-value payloads per bundle and write them in one RPC.

    Each input element is a ``(user, events)`` pair. ``process`` converts it
    into a ``WriteFeatureValuesPayload`` and appends it to an in-memory
    buffer; ``finish_bundle`` sends the whole buffer with a single
    ``write_feature_values`` call and then empties it.

    Three distributions are reported under the ``"writer"`` namespace, all
    in milliseconds: ``client_setup_time``, ``serialize_time`` and
    ``write_time``.
    """

    def __init__(
        self,
        api_endpoint,
        project_id,
        location,
        featurestore_id,
        entity_type_id,
        feature_name,
    ):
        """Store the write target and create the metric handles.

        Args:
            api_endpoint: Value passed as ``api_endpoint`` in the client's
                ``client_options``.
            project_id: First component of the entity type path.
            location: Second component of the entity type path.
            featurestore_id: Third component of the entity type path.
            entity_type_id: Fourth component of the entity type path.
            feature_name: Key under which the string-array feature value
                is written for every entity.
        """
        self.write_time_dist = Metrics.distribution("writer", "write_time")
        self.serialize_time_dist = Metrics.distribution("writer", "serialize_time")
        self.client_setup_time_dist = Metrics.distribution(
            "writer", "client_setup_time"
        )
        self.project_id = project_id
        self.location = location
        self.featurestore_id = featurestore_id
        self.entity_type_id = entity_type_id
        self.api_endpoint = api_endpoint
        self.feature_name = feature_name
        # Defined here (not only in setup) so that every lifecycle method can
        # rely on the attributes existing, whatever order the runner uses.
        self.client = None
        self.buffer = []

    def setup(self):
        """Create the serving client and record how long that took."""
        start_time = time.time()
        self.client = self._create_client()
        self.client_setup_time_dist.update(int((time.time() - start_time) * 1000))
        self.buffer = []

    def start_bundle(self):
        """Start every bundle with an empty buffer.

        Guards against payloads left behind when a previous ``finish_bundle``
        raised before it could clear the buffer.
        """
        self.buffer = []

    def process(self, element):
        """Serialize one ``(user, events)`` element and buffer the payload.

        The entity id is ``"<user.id>_<user.market>"``. Two feature values are
        written: ``self.feature_name`` holding the ``article`` attribute of
        every event, and ``"market"`` holding the user's market.

        Nothing is emitted downstream; the write happens in ``finish_bundle``.
        """
        start_time = time.time()
        user, events = element
        entity_id = user.id + "_" + user.market
        feature_values = {
            self.feature_name: featurestore_online_service_pb2.FeatureValue(
                string_array_value=types_pb2.StringArray(
                    values=[event.article for event in events]
                )
            ),
            "market": featurestore_online_service_pb2.FeatureValue(
                string_value=user.market
            ),
        }
        payload = featurestore_online_service_pb2.WriteFeatureValuesPayload(
            entity_id=entity_id, feature_values=feature_values
        )
        self.serialize_time_dist.update(int((time.time() - start_time) * 1000))
        self.buffer.append(payload)

    def _create_client(self):
        """Build a serving client bound to ``self.api_endpoint``."""
        return FeaturestoreOnlineServingServiceClient(
            client_options={"api_endpoint": self.api_endpoint}
        )

    def finish_bundle(self):
        """Write all buffered payloads in one request, then clear the buffer.

        Does nothing when the bundle produced no payloads, so no RPC is sent
        with an empty ``payloads`` list.
        """
        if not self.buffer:
            return
        if self.client is None:
            self.client = self._create_client()
        # NOTE(review): the whole bundle goes out in a single request; if the
        # service caps payloads per call, large bundles may need chunking.
        start_time = time.time()
        entity_type_path = self.client.entity_type_path(
            self.project_id, self.location, self.featurestore_id, self.entity_type_id
        )
        self.client.write_feature_values(
            entity_type=entity_type_path,
            payloads=self.buffer,
        )
        self.write_time_dist.update(int((time.time() - start_time) * 1000))
        self.buffer.clear()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment