Skip to content

Instantly share code, notes, and snippets.

@atemate
Created December 22, 2021 07:51
Show Gist options
  • Save atemate/20c0c40068a254f36d159d5a54194beb to your computer and use it in GitHub Desktop.
Save atemate/20c0c40068a254f36d159d5a54194beb to your computer and use it in GitHub Desktop.
Get BQML job statistics
import time

from google.cloud import bigquery
from google.cloud.bigquery import magics
def _table_reference_to_dict(table):
    """Convert a table reference (project/dataset/table ids) to a plain dict."""
    return {
        "project_id": str(table.project_id),
        "dataset_id": str(table.dataset_id),
        "table_id": str(table.table_id),
    }


def get_job_statistics(job, model_name):
    """Collect statistics about a finished model-training job as a plain dict.

    The model is looked up through the module-level ``client`` under
    ``{PROJECT_ID}.{DATASET}.{model_name}``; only its first training run is
    reported. Metrics are read from ``regression_metrics``, so the model is
    expected to be a regression model.
    NOTE(review): ``max_tree_depth`` / ``subsample`` suggest a boosted-tree
    model -- confirm these options are populated for other model types.

    Args:
        job: The query job that created the model. Must be in state "DONE".
        model_name: Model id inside ``PROJECT_ID.DATASET``.

    Returns:
        A dict with three kinds of entries:
        - "model": training duration, model identity, and under "training"
          the options, evaluation metrics and data split tables of the
          first training run;
        - "bigquery_job": the job's API representation;
        - "total_mb_processed" / "total_mb_billed": byte counters
          floor-divided by 1024*1024 (0 when the counter is unset).

    Raises:
        ValueError: If the job is not in state "DONE" or the model has no
            training runs.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if job.state != "DONE":
        raise ValueError(f"Job is not finished (state={job.state})")

    model = client.get_model(f"{PROJECT_ID}.{DATASET}.{model_name}")

    runs = model.training_runs
    run_count = len(runs)
    if run_count == 0:
        # Indexing [0] below would otherwise fail with an opaque IndexError.
        raise ValueError(f"Model has no training runs: {model_name}")
    if run_count > 1:
        print(f"WARNING: Multiple training runs (taking the first): {model_name}")

    run = runs[0]
    opts = run.training_options
    metrics = run.evaluation_metrics.regression_metrics
    split = run.data_split_result

    bytes_per_mb = 1024 * 1024

    return {
        "model": {
            "training_time_minutes": (job.ended - job.started).total_seconds() / 60,
            "project": str(model.project),
            "dataset_id": str(model.dataset_id),
            "model_id": str(model.model_id),
            "location": str(model.location),
            "training": {
                "training_options": {
                    "max_iterations": int(opts.max_iterations),
                    "max_tree_depth": int(opts.max_tree_depth),
                    "learn_rate": float(opts.learn_rate),
                    "subsample": float(opts.subsample),
                    # Wrapper messages expose the payload via `.value`.
                    "early_stop": bool(opts.early_stop.value),
                    # protobuf's DoubleValue -> float:
                    "l1_regularization": float(opts.l1_regularization.value),
                    "l2_regularization": float(opts.l2_regularization.value),
                    "min_relative_progress": float(opts.min_relative_progress.value),
                },
                "evaluation_metrics": {
                    # protobuf's DoubleValue -> float:
                    "mean_absolute_error": float(metrics.mean_absolute_error.value),
                    "mean_squared_error": float(metrics.mean_squared_error.value),
                    "mean_squared_log_error": float(metrics.mean_squared_log_error.value),
                    "median_absolute_error": float(metrics.median_absolute_error.value),
                    "r_squared": float(metrics.r_squared.value),
                },
                "start_time_seconds": run.start_time.seconds,
                "data_split_result": {
                    "training_table": _table_reference_to_dict(split.training_table),
                    "evaluation_table": _table_reference_to_dict(split.evaluation_table),
                },
            },
        },
        "bigquery_job": job.to_api_repr(),
        # The byte counters may be None; treat that as zero.
        "total_mb_processed": (job.total_bytes_processed or 0) // bytes_per_mb,
        "total_mb_billed": (job.total_bytes_billed or 0) // bytes_per_mb,
    }
# Example usage: train a model, wait for the job to finish, print its stats.
# PROJECT_ID and DATASET are expected to be defined before this point.
model_name = "..."
client = bigquery.Client(project=PROJECT_ID)

# Keep the returned job handle: the polling loop and the statistics call
# below both need it.
job = client.query(f"CREATE MODEL `{PROJECT_ID}.{DATASET}.{model_name}` ...")

# Poll until the job finishes. Sleeping before the reload avoids a wasted
# one-second wait after the job has already reached DONE.
while job.state != "DONE":
    time.sleep(1)
    job.reload()
    print(job.state)

# A job in state DONE may still have failed; surface that before querying
# the model it was supposed to create.
if job.error_result:
    raise RuntimeError(f"CREATE MODEL job failed: {job.error_result}")

print(get_job_statistics(job, model_name))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment