import requests
import re
import json

from airflow import DAG
from airflow.contrib.sensors.gcs_sensor import GoogleCloudStoragePrefixSensor
from airflow.contrib.operators.gcs_to_bq import GoogleCloudStorageToBigQueryOperator
from airflow.exceptions import AirflowException
from airflow.hooks.http_hook import HttpHook
from airflow.operators.http_operator import SimpleHttpOperator
from airflow.operators.python_operator import PythonOperator
from airflow.utils.dates import days_ago
from datetime import timedelta, datetime
from google.cloud import automl_v1beta1 as automl
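
# Pipeline settings: Cloud Functions base URL, GCE metadata identity-token endpoint,
# GCP project / AutoML Tables model, and Slack incoming-webhook configuration.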
cloud_functions_url = 'https://asia-northeast1-inference-pipeline.cloudfunctions.net'
metadata_url = 'http://metadata/computeMetadata/v1/instance/service-accounts/default/identity?audience='
project_id = 'inference-pipeline'
automl_tables_region = 'us-central1'
model_id = 'TBL0000000000000000000'

web_hook_url = 'https://hooks.slack.com/services/hoge'
web_hook_name = 'hoge'
slack_channel_name = 'hoge channel'
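
# Slack notification callbacks. Airflow passes the task context as a dict; the DAG/task
# names are extracted from the repr strings (e.g. '<DAG: inference_pipeline>') and posted
# to the incoming webhook as a colored attachment.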
def success(status):
    dag_name = re.findall(r'.*:\s(.*)>', str(status['dag']))[0]
    data = {
        'username': web_hook_name,
        'channel': slack_channel_name,
        'attachments': [{
            'fallback': dag_name,
            'color': '#1e88e5',
            'title': dag_name,
            'text': f'{dag_name} succeeded'
        }]
    }
    requests.post(web_hook_url, json.dumps(data))

def fail(status):
    dag_name = re.findall(r'.*:\s(.*)>', str(status['dag']))[0]
    task_name = re.findall(r'.*:\s(.*)>', str(status['task']))[0]
    data = {
        'username': web_hook_name,
        'channel': slack_channel_name,
        'attachments': [{
            'fallback': f'{dag_name}:{task_name}',
            'color': '#e53935',
            'title': f'{dag_name}:{task_name}',
            'text': f'{task_name} failed'
        }]
    }
    requests.post(web_hook_url, json.dumps(data))
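
# Daily DAG: each task retries once after 5 minutes, a DAG run times out after 60 minutes,
# and any task failure is reported to Slack via fail().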
dag = DAG(
    'inference_pipeline',
    default_args={
        'start_date': days_ago(1),
        'retries': 1,
        'retry_delay': timedelta(minutes=5),
        'on_failure_callback': fail  # Run fail() on task failure and notify Slack
    },
    schedule_interval='@daily',
    dagrun_timeout=timedelta(minutes=60),
    catchup=False
)
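
# Custom HTTP operators for calling Cloud Functions. Both fetch an identity token for the
# target function from the GCE metadata server and send it as a Bearer token; the "WithOpt"
# variant additionally appends the predicted-output directory (pulled from the 'predict'
# task's XCom) as a ?dir= query parameter.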
class RunCloudFunctionsOperator(SimpleHttpOperator):
    def execute(self, context):
        http = HttpHook(self.method, http_conn_id=self.http_conn_id)
        self.log.info("Calling HTTP method")

        target_audience = cloud_functions_url + self.endpoint
        fetch_instance_id_token_url = metadata_url + target_audience
        r = requests.get(fetch_instance_id_token_url, headers={"Metadata-Flavor": "Google"}, verify=False)
        idt = r.text

        self.headers = {'Authorization': "Bearer " + idt}
        response = http.run(self.endpoint,
                            self.data,
                            self.headers,
                            self.extra_options)
        if self.response_check:
            if not self.response_check(response):
                raise AirflowException("Response check returned False.")

class RunCloudFunctionsWithOptOperator(SimpleHttpOperator):
    def execute(self, context):
        http = HttpHook(self.method, http_conn_id=self.http_conn_id)
        self.log.info("Calling HTTP method")

        target_audience = cloud_functions_url + self.endpoint
        fetch_instance_id_token_url = metadata_url + target_audience
        r = requests.get(fetch_instance_id_token_url, headers={"Metadata-Flavor": "Google"}, verify=False)
        idt = r.text

        gcs_output_dir = context['ti'].xcom_pull(key='predicted output directory', task_ids='predict')
        endpoint = self.endpoint + f'?dir={gcs_output_dir}'

        self.headers = {'Authorization': "Bearer " + idt}
        response = http.run(endpoint,
                            self.data,
                            self.headers,
                            self.extra_options)
        if self.response_check:
            if not self.response_check(response):
                raise AirflowException("Response check returned False.")
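
# Wait (up to two days) for today's CSV, gs://test/data/YYYYMMDD-*, to appear.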
csv_sensor = GoogleCloudStoragePrefixSensor(
    task_id='csv_sensor',
    bucket='test',
    prefix='data/{}-'.format(datetime.now().strftime('%Y%m%d')),
    timeout=60 * 60 * 24 * 2,
    pool='csv_sensor',
    dag=dag
)
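
# Call the /preprocessing Cloud Function. The next task loads preprocess_data/*.csv,
# so the function is assumed to write its output under that prefix.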
preprocessing = RunCloudFunctionsOperator(
    task_id='preprocessing',
    method='GET',
    http_conn_id='http_default',
    endpoint='/preprocessing',
    headers={},
    xcom_push=False,
    response_check=lambda response: response.status_code == 200,
    dag=dag,
)
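
# Load the preprocessed CSVs into BigQuery (test.data), replacing the table contents
# on every run (WRITE_TRUNCATE).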
import_bq = GoogleCloudStorageToBigQueryOperator(
    task_id='import_bq',
    bucket='test',
    source_objects=['preprocess_data/*.csv'],
    source_format='CSV',
    allow_quoted_newlines=True,
    skip_leading_rows=1,
    destination_project_dataset_table='test.data',
    schema_fields=[
        {'name': 'id', 'type': 'INTEGER'},
    ],
    write_disposition='WRITE_TRUNCATE',
    dag=dag
)
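
# Call the /postprocessing Cloud Function on the batch-prediction output directory,
# then notify Slack on success.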
postprocessing = RunCloudFunctionsWithOptOperator(
    task_id='postprocessing',
    method='GET',
    http_conn_id='http_default',
    endpoint='/postprocessing',
    headers={},
    xcom_push=False,
    response_check=lambda response: response.status_code == 200,
    dag=dag,
    on_success_callback=success  # Notify Slack when the task succeeds
)
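
# AutoML Tables helpers: deploy the model, run a batch prediction
# (BigQuery table -> GCS), and undeploy the model afterwards.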
def do_deploy_model():
    client = automl.AutoMlClient()
    model_full_id = client.model_path(project_id, automl_tables_region, model_id)
    response = client.deploy_model(model_full_id)
    print(u'Model deployment finished. {}'.format(response.result()))
    return

def do_predict(**kwargs):
    client = automl.AutoMlClient()
    model_full_id = client.model_path(project_id, automl_tables_region, model_id)
    predict_client = automl.PredictionServiceClient()

    input_uri = 'bq://inference-pipeline.test.data'
    input_config = {"bigquery_source": {"input_uri": input_uri}}
    output_uri = 'gs://test/predicted_data'
    output_config = {"gcs_destination": {"output_uri_prefix": output_uri}}

    response = predict_client.batch_predict(model_full_id, input_config, output_config)
    result = response.result()  # block until the batch prediction completes

    gcs_output_dir = response.metadata.batch_predict_details.output_info.gcs_output_directory
    kwargs['ti'].xcom_push(key='predicted output directory', value=gcs_output_dir)
    print(u'Predict finished. {}'.format(result))
    return

def do_delete_model():
    client = automl.AutoMlClient()
    model_full_id = client.model_path(project_id, automl_tables_region, model_id)
    response = client.undeploy_model(model_full_id)
    print(u'Model undeploy finished. {}'.format(response.result()))
    return
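
# Wrap the AutoML helpers as tasks. delete_model uses trigger_rule='all_done' so the
# model is undeployed even if deploy_model or predict fails.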
deploy_model = PythonOperator(
    task_id='deploy_model',
    dag=dag,
    python_callable=do_deploy_model,
)

predict = PythonOperator(
    task_id='predict',
    dag=dag,
    provide_context=True,
    python_callable=do_predict,
)

delete_model = PythonOperator(
    task_id='delete_model',
    trigger_rule='all_done',
    dag=dag,
    python_callable=do_delete_model,
)

# Set up task dependencies
csv_sensor >> preprocessing >> import_bq >> deploy_model >> predict >> delete_model >> postprocessing