This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def poke(self, context): | |
| hook = GCSHook( | |
| gcp_conn_id=self.google_cloud_conn_id, | |
| delegate_to=self.delegate_to, | |
| impersonation_chain=self.impersonation_chain, | |
| ) | |
| for bucket in self.buckets: | |
| match = [{'bucket': bucket, 'object': object} | |
| for object in hook.list(bucket)] | |
| self._matches.extend(match) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| name: Deployment Checker | |
| description: The component checks whether to deploy model new version or not | |
| inputs: | |
| - name: project_id | |
| description: '' | |
| type: String | |
| - name: model_name | |
| description: '' | |
| type: String | |
| - name: model_folder |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def download_blob(source_bucket_name: str, source_blob_name: str, dest_file_path: OutputPath()):
    # Download a single Cloud Storage object to the local file path that KFP
    # supplies for this component's output.
    #
    # The storage import stays inside the function body: this function is
    # packaged with comp.create_component_from_func, which runs the
    # function's own source in the container, so module-level imports are not
    # carried along. Comments (not a docstring) are used so the generated
    # component description is unchanged.
    from google.cloud import storage

    # Build the client, resolve the bucket and object handles, then write the
    # object's bytes to the destination path.
    target_blob = storage.Client().bucket(source_bucket_name).blob(source_blob_name)
    target_blob.download_to_filename(dest_file_path)
| get_model_location_task = GET_MODEL_LOCATION_OP(bucket_name, model_folder) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ml_engine_deploy_task = ML_ENGINE_DEPLOY_OP( | |
| model_uri=get_model_location_task.outputs['location'], | |
| project_id=project_id, | |
| model_id=MODEL_NAME, | |
| version_id='version_{}'.format(MODEL_FOLDER), | |
| runtime_version=RUNTIME_VERSION, | |
| python_version='3.5', | |
| set_default=True, | |
| wait_interval=WAIT_INTERVAL) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def get_model_location(bucket_name: str, model_folder: str) -> NamedTuple('returns', [('location', 'GCSPath')]): | |
| from google.cloud import storage | |
| folder = 'model/{}/export/exporter/'.format(model_folder) | |
| storage_client = storage.Client() | |
| bucket = storage_client.get_bucket(bucket_name) | |
| blobs = bucket.list_blobs(prefix=folder) | |
| blob_name = None | |
| for blob in blobs: | |
| tmp = len(blob.name.rstrip('/').split('/')) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def trigger(project_id, bucket_name, train_steps, model_folder, host): | |
| client = kfp.Client(namespace='default', host=host) | |
| params = { | |
| 'project_id': project_id, | |
| 'bucket_name': bucket_name, | |
| 'train_steps': train_steps, | |
| 'model_folder': model_folder | |
| } | |
| experiments = client.list_experiments() | |
| experiment_id = experiments.experiments[0].id |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| with dsl.Condition(is_retraining_needed_task.output == 'yes'): | |
| trigger_model_training_pipeline_task = \ | |
| trigger_model_training_pipeline_op(project_id, bucket_name, train_steps, model_folder, host) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def store_training_job_metrics_op(project_id: str, metrics_file_path: str, model_name: str, model_folder: str): | |
| return dsl.ContainerOp( | |
| name='store-training-job-metrics', | |
| image='gcr.io/{}/taxi-fare-utils:1.0'.format(project_id), | |
| command=['python3', '-m', 'metrics_writer'], | |
| arguments=[ | |
| '--metrics_file_path', metrics_file_path, | |
| '--model_name', model_name, | |
| '--model_folder', model_folder | |
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def download_blob(source_bucket_name: str, source_blob_name: str, dest_file_path: OutputPath()):
    # Download a single Cloud Storage object to the local file path that KFP
    # supplies for this component's output.
    #
    # The storage import stays inside the function body: this function is
    # packaged with comp.create_component_from_func, which runs the
    # function's own source in the container, so module-level imports are not
    # carried along. Comments (not a docstring) are used so the generated
    # component description is unchanged.
    from google.cloud import storage

    # Build the client, resolve the bucket and object handles, then write the
    # object's bytes to the destination path.
    target_blob = storage.Client().bucket(source_bucket_name).blob(source_blob_name)
    target_blob.download_to_filename(dest_file_path)
| DOWNLOAD_BLOB_OP = comp.create_component_from_func(func=download_blob, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ML_ENGINE_TRAIN_OP = comp.load_component_from_url( | |
| 'https://raw.githubusercontent.com/kubeflow/pipelines/1.1.1-beta.1/components/gcp/ml_engine/train/component.yaml') | |
| ml_engine_train_task = ML_ENGINE_TRAIN_OP( | |
| project_id=project_id, | |
| python_module=PYTHON_MODULE, | |
| package_uris=package_uris, | |
| region=REGION, | |
| args=args, | |
| job_dir=job_dir, |