"""
prefect deployment build week_2_dataflow/main.py:taxi_data -n yellow -q default -a
prefect deployment build week_2_dataflow/main.py:taxi_data -n green -q default -a --param table_name=green_tripdata
prefect deployment build week_2_dataflow/main.py:parent -n yellow -q default -a
prefect deployment build week_2_dataflow/main.py:parent -n green -q default -a --param table_name=green_tripdata
"""
import awswrangler as wr
import pandas as pd
from prefect import task, flow, get_run_logger
from prefect.task_runners import SequentialTaskRunner
from prefect.blocks.system import JSON

from week_2_dataflow.pandas_bq_block import BigQueryPandas
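
# BigQueryPandas is a custom block defined in week_2_dataflow/pandas_bq_block.py,
# which is not part of this gist. A minimal sketch of what it could look like,
# inferred from how it is used below and assuming prefect-gcp and pandas-gbq are
# installed (auth is elided; the actual implementation may differ):
#
#     from prefect.blocks.core import Block
#     from prefect_gcp import GcpCredentials
#
#     class BigQueryPandas(Block):
#         credentials: GcpCredentials
#
#         def create_dataset_if_not_exists(self, dataset: str) -> None:
#             client = self.credentials.get_bigquery_client()
#             client.create_dataset(dataset, exists_ok=True)  # no-op if it exists
#
#         def load_data(self, dataframe: pd.DataFrame, table_name: str, **kwargs) -> None:
#             # kwargs such as if_exists="replace" pass through to pandas-gbq
#             dataframe.to_gbq(destination_table=table_name, **kwargs)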


@task
def get_files_to_process(table_name: str):
    # Block document names only allow dashes, hence the underscore replacement
    block = JSON.load(table_name.replace("_", "-"))
    return block.value["files"]
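
# The JSON block loaded above has to be registered first. A hedged example of a
# one-off registration script (the block name and file list are illustrative):
#
#     from prefect.blocks.system import JSON
#
#     JSON(value={"files": ["yellow_tripdata_2022-01.parquet"]}).save(
#         "yellow-tripdata", overwrite=True
#     )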


@task
def extract_from_s3(file_name: str) -> pd.DataFrame:
    logger = get_run_logger()
    # Read one month of NYC taxi trip data from the public nyc-tlc bucket
    raw_df = wr.s3.read_parquet(f"s3://nyc-tlc/trip data/{file_name}")
    logger.info("Extracted %s with %d rows", file_name, len(raw_df))
    return raw_df


@task
def load(df: pd.DataFrame, file: str, tbl: str, **kwargs) -> None:
    logger = get_run_logger()
    block = BigQueryPandas.load("default")
    # Forward extra kwargs (e.g. if_exists="replace") to the block's loader
    block.load_data(dataframe=df, table_name=tbl, **kwargs)
    ref = block.credentials.get_bigquery_client().get_table(tbl)
    logger.info(
        "Loaded %s to %s βœ… table now has %d rows and %.2f GB",
        file,
        tbl,
        ref.num_rows,
        ref.num_bytes / 1_000_000_000,
    )


@flow(task_runner=SequentialTaskRunner())
def taxi_data(
    file: str = "yellow_tripdata_2022-06.parquet",
    dataset: str = "trips_data_all",
    table_name: str = "yellow_tripdata",
    **kwargs,
):
    # Extract a single file from S3 and load it into BigQuery
    tbl = f"{dataset}.{table_name}"
    block = BigQueryPandas.load("default")
    block.create_dataset_if_not_exists(dataset)
    df = extract_from_s3.with_options(name=f"πŸ—‚οΈ extract_{file}").submit(file)
    load.with_options(name=f"πŸš€ load_{file}").submit(df, file, tbl, **kwargs)


@flow(task_runner=SequentialTaskRunner())
def parent(
    dataset: str = "trips_data_all", table_name: str = "yellow_tripdata", **kwargs
):
    # Process every file listed in the table's JSON block, one file at a time
    files = get_files_to_process(table_name)
    tbl = f"{dataset}.{table_name}"
    block = BigQueryPandas.load("default")
    block.create_dataset_if_not_exists(dataset)
    for file in files:
        df = extract_from_s3.with_options(name=f"πŸ—‚οΈ extract_{file}").submit(file)
        load.with_options(name=f"πŸš€ load_{file}").submit(df, file, tbl, **kwargs)


if __name__ == "__main__":
    taxi_data(if_exists="replace")
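
# The extra keyword arguments flow through to BigQueryPandas.load_data, so
# (assuming the block forwards them to pandas-gbq as in the sketch near the
# imports) an incremental load could look like:
#
#     taxi_data(file="yellow_tripdata_2022-07.parquet", if_exists="append")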