This Cloud Function kicks off a Vertex Pipelines run whenever a specified number of new rows has been added to a BigQuery training table. Deploy it as an HTTP function and set up a Cloud Scheduler job to run it on a recurring basis. See this blog post for details: https://cloud.google.com/blog/topics/developers-practitioners/lets-get-it-s…
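The HTTP request body (for example, the body configured on the Cloud Scheduler job) should name the dataset and table to watch. The field names come from the function below; the values here are placeholders:

{"bq_dataset": "your_dataset", "bq_table": "your_table"}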
# Copyright 2021 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import json
import time

from google.cloud import bigquery
from google.cloud.exceptions import NotFound
from kfp.v2.google.client import AIPlatformClient

client = bigquery.Client()

# Kick off a new pipeline run once this many rows have been added since the last run.
RETRAIN_THRESHOLD = 1000  # Change this based on your use case

def insert_bq_data(table_id, num_rows):
    """Record the current row count and timestamp in the `count` tracking table."""
    rows_to_insert = [
        {"num_rows_last_retraining": num_rows, "last_retrain_time": time.time()}
    ]
    errors = client.insert_rows_json(table_id, rows_to_insert)
    if errors == []:
        print("New rows have been added.")
    else:
        print(f"Encountered errors while inserting rows: {errors}")

def create_count_table(table_id, num_rows):
    """Create the `count` tracking table and seed it with the current row count."""
    schema = [
        bigquery.SchemaField("num_rows_last_retraining", "INTEGER", mode="REQUIRED"),
        bigquery.SchemaField("last_retrain_time", "TIMESTAMP", mode="REQUIRED"),
    ]
    table = bigquery.Table(table_id, schema=schema)
    table = client.create_table(table)
    print(f"Created table {table.project}.{table.dataset_id}.{table.table_id}")
    insert_bq_data(table_id, num_rows)

def create_pipeline_run():
    """Submit the pre-compiled pipeline spec to Vertex Pipelines."""
    print("Kicking off a pipeline run...")
    REGION = "us-central1"  # Change this to the region you want to run in
    api_client = AIPlatformClient(
        project_id=client.project,
        region=REGION,
    )
    try:
        response = api_client.create_run_from_job_spec(
            "compiled_pipeline.json",
            pipeline_root="gs://your-gcs-bucket/pipeline_root/",
            parameter_values={"project": client.project, "display_name": "pipeline_gcf_trigger"},
        )
        return response
    except Exception:
        print("Error trying to run the pipeline")
        raise
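# Note: "compiled_pipeline.json" above is a pipeline spec compiled ahead of time and
# deployed alongside this function. A minimal sketch of producing it, assuming a
# hypothetical @kfp.dsl.pipeline-decorated function named `my_pipeline` (run this
# separately, not inside the Cloud Function):
#
#   from kfp.v2 import compiler
#   compiler.Compiler().compile(
#       pipeline_func=my_pipeline,
#       package_path="compiled_pipeline.json",
#   )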

# This should be the entrypoint for your Cloud Function
def check_table_size(request):
    """Compare the data table's current row count with the count at the last pipeline run."""
    try:
        request_json = json.loads(request.get_data().decode())
    except ValueError as e:
        print(f"Error decoding JSON: {e}")
        return "JSON Error", 400

    if request_json and "bq_dataset" in request_json and "bq_table" in request_json:
        dataset = request_json["bq_dataset"]
        table = request_json["bq_table"]
        data_table = client.get_table(f"{client.project}.{dataset}.{table}")
        current_rows = data_table.num_rows
        print(f"{table} table has {current_rows} rows")

        # See if the `count` tracking table exists in the dataset
        try:
            client.get_table(f"{client.project}.{dataset}.count")
            print("Count table exists, querying to see how many rows at last pipeline run")
        except NotFound:
            # First run: create the tracking table, record the current count, and exit.
            print("No count table found, creating one...")
            create_count_table(f"{client.project}.{dataset}.count", current_rows)
            return "Created count table"

        query_job = client.query(
            f"""
            SELECT num_rows_last_retraining FROM `{client.project}.{dataset}.count`
            ORDER BY last_retrain_time DESC
            LIMIT 1"""
        )
        results = query_job.result()
        for row in results:
            last_retrain_count = row[0]
        rows_added_since_last_pipeline_run = current_rows - last_retrain_count
        print(f"{rows_added_since_last_pipeline_run} rows have been added since we last ran the pipeline")

        if rows_added_since_last_pipeline_run >= RETRAIN_THRESHOLD:
            create_pipeline_run()
            insert_bq_data(f"{client.project}.{dataset}.count", current_rows)
            return "Pipeline run started"
        return "Not enough new data to retrain"
    else:
        return "No BigQuery data given", 400