Last active
December 19, 2021 12:54
-
-
Save MachineLearningIsEasy/80bc0ef00c30b16840be80e13445207c to your computer and use it in GitHub Desktop.
Airflow dag with ml-model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
import time | |
import pandas as pd | |
from airflow import DAG | |
from airflow.operators.python_operator import PythonOperator, PythonVirtualenvOperator | |
from airflow.operators.postgres_operator import PostgresOperator | |
from airflow.models import Variable | |
import psycopg2 | |
import os | |
import sys | |
import sklearn | |
from sklearn.svm import SVC | |
import joblib | |
import pickle | |
# Connection settings for the Postgres database read by get_data_and_predict.
# NOTE(review): credentials are hard-coded in source. The values below are
# fallbacks only: anything already present in the process environment wins,
# so a deployment can override them without editing this file. Prefer
# supplying them from the environment or a secrets backend.
os.environ.setdefault("HOST_DB", '192.168.176.4')
os.environ.setdefault("HOST_DB_port", '5432')
os.environ.setdefault("DB_name", 'rainbow_database')
os.environ.setdefault("USER_DB", 'unicorn_user')
os.environ.setdefault("PASSWORD_DB", 'magical_password')
import os | |
# Default task arguments; passed to the DAG as ``default_args``.
args = dict(
    owner='dimon',
    start_date=datetime.datetime(2018, 11, 1),
    provide_context=True,
)
def get_data_and_predict(**kwargs):
    """Read one row from the ``iris`` table, classify it, and push the result to XCom.

    Connects to Postgres using the ``HOST_DB``, ``HOST_DB_port``, ``DB_name``,
    ``USER_DB`` and ``PASSWORD_DB`` environment variables, runs
    ``SELECT * FROM iris limit 1``, loads the classifier from
    ``/usr/local/airflow/dags/clf.pkl`` and calls ``predict`` on every column
    except the first.

    Args:
        **kwargs: Task context; must contain ``ti``, the object whose
            ``xcom_push`` method is used to publish the result.

    Raises:
        ValueError: If the query returns no rows.

    Side effects:
        Pushes a ``(first_column_value, prediction)`` tuple to XCom under the
        key ``iris_predictions``.
    """
    ti = kwargs['ti']

    conn = psycopg2.connect(
        host=os.environ["HOST_DB"],
        port=os.environ["HOST_DB_port"],
        database=os.environ["DB_name"],
        user=os.environ["USER_DB"],
        password=os.environ["PASSWORD_DB"],
    )
    try:
        # The cursor is closed on leaving the ``with`` block and the
        # connection in ``finally``, so neither leaks if the query raises.
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM iris limit 1")
            query_results = cur.fetchall()
            col_names = [column[0] for column in cur.description]
    finally:
        conn.close()

    # Fail with a clear message instead of an opaque indexing error downstream.
    if not query_results:
        raise ValueError("Query on table 'iris' returned no rows; nothing to predict")

    data = pd.DataFrame(query_results, columns=col_names)

    # NOTE(review): joblib.load unpickles the file; only load models from a
    # trusted location.
    clf = joblib.load('/usr/local/airflow/dags/clf.pkl')

    # The first column is passed through untouched (presumably a row
    # identifier - confirm against the table schema); the remaining columns
    # are the model's input.
    first_column_value = data[data.columns[0]][0]
    prediction = clf.predict(data[data.columns[1:]])[0]
    ti.xcom_push(key='iris_predictions', value=(first_column_value, prediction))
# DAG with two tasks: run the prediction, then insert its XCom result into
# the ``iris_predict`` table.
with DAG(
    'ml_model_predict',
    description='ml_model_predict',
    # Runs every minute; the original inline note also listed '0 * * * *'
    # as an alternative schedule.
    schedule_interval='*/1 * * * *',
    catchup=False,
    default_args=args,
) as dag:
    # Named differently from the callable so the module-level function
    # ``get_data_and_predict`` is not rebound to an operator instance.
    predict_task = PythonOperator(
        task_id='get_data_and_predict',
        python_callable=get_data_and_predict,
    )

    # The Jinja expressions pull the (value, prediction) tuple pushed by the
    # task above. ``task_ids`` is given as a list, so ``xcom_pull`` returns a
    # list and ``[0]`` selects that task's value.
    # NOTE(review): both values are rendered straight into the SQL text
    # without quoting; a non-numeric prediction (e.g. a class label string)
    # would presumably produce invalid SQL - confirm the model's output type.
    insert_in_table = PostgresOperator(
        task_id="insert_in_table",
        postgres_conn_id="database_PG",
        sql=["""INSERT INTO iris_predict VALUES(
            {{ ti.xcom_pull(key='iris_predictions', task_ids=['get_data_and_predict'])[0][0] }},
            {{ ti.xcom_pull(key='iris_predictions', task_ids=['get_data_and_predict'])[0][1] }})
        """],
    )

    predict_task >> insert_in_table
# Add to docker-compose:
#   database:
#     image: "postgres"  # use latest official postgres version
#     env_file:
#       - database.env  # configure postgres
#     ports:
#       - "5423:5432"
#
# database.env:
#   POSTGRES_USER=unicorn_user
#   POSTGRES_PASSWORD=magical_password
#   POSTGRES_DB=rainbow_database
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment