Gist by @Geremie, last active November 16, 2020 22:58
Easily improve your company's machine learning system with cost-efficient automated model retraining
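```python
# Airflow 1.10.x path (Airflow 2: airflow.providers.postgres.hooks.postgres).
from airflow.hooks.postgres_hook import PostgresHook
import utils  # assumption: the gist's helper module (not shown) is importable as `utils`
# MODEL_NAME and ACCEPTABLE_GAP are module-level constants defined elsewhere in the DAG file.

def is_retraining_needed(context, dag_run_obj):
    """Return dag_run_obj to trigger retraining, or None to skip it."""
    current_version = utils.get_model_current_version(MODEL_NAME)
    hook = PostgresHook(postgres_conn_id='cloud_sql_proxy_conn')
    # Parameterized queries instead of string concatenation (avoids SQL injection).
    records = hook.get_records(
        "SELECT labelMean FROM training_jobs WHERE versionName = %s",
        parameters=(current_version,))
    if not records:
        # Probably caused by a manual training job not persisting its metrics:
        # there is no reference value to compare against, so retrain.
        return dag_run_obj
    label_mean = float(records[0][0])
    print('label mean: {}'.format(label_mean))
    predictions = hook.get_pandas_df(
        "SELECT prediction FROM predictions WHERE versionName = %s",
        parameters=(current_version,))
    prediction_mean = predictions['prediction'].mean()
    print('prediction mean: {}'.format(prediction_mean))
    # Relative band around the label mean; abs() keeps _min <= _max
    # even when the label mean is negative.
    tolerance = abs(label_mean) * ACCEPTABLE_GAP
    _min, _max = label_mean - tolerance, label_mean + tolerance
    if prediction_mean < _min or prediction_mean > _max:
        return dag_run_obj
    return None
```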
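The retraining decision itself does not depend on Airflow or Postgres, so it can be pulled out into a plain function and unit-tested. The sketch below assumes the label mean and the prediction values have already been fetched; `needs_retraining` is a hypothetical helper name, not part of the original gist. It mirrors the callable above: missing training metrics trigger retraining, while an empty prediction set does not (pandas returns NaN for the mean of an empty column, and NaN fails both comparisons).

```python
from typing import Optional, Sequence


def needs_retraining(label_mean: Optional[float],
                     predictions: Sequence[float],
                     acceptable_gap: float) -> bool:
    """Hypothetical helper: True if the mean prediction leaves the band
    [label_mean - gap * |label_mean|, label_mean + gap * |label_mean|]."""
    if label_mean is None:
        # No recorded training metrics: nothing to compare against, so retrain.
        return True
    if not predictions:
        # No predictions served yet: no evidence of drift.
        return False
    prediction_mean = sum(predictions) / len(predictions)
    # Relative tolerance; abs() keeps the band ordered for negative label means.
    tolerance = abs(label_mean) * acceptable_gap
    return not (label_mean - tolerance <= prediction_mean <= label_mean + tolerance)
```

For example, with a label mean of 0.30 and `acceptable_gap=0.1` the accepted band is roughly [0.27, 0.33], so predictions averaging 0.39 return `True`. Computing the band from `abs(label_mean)` matters for negative targets: with `label_mean * (1 - gap)` and `label_mean * (1 + gap)`, a label mean of -2.0 gives `_min = -1.8` and `_max = -2.2`, and every prediction would trigger retraining.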