Skip to content

Instantly share code, notes, and snippets.

@ntakouris
Created July 14, 2020 16:59
Show Gist options
  • Save ntakouris/87765bc7aff78d9005607cff714ed178 to your computer and use it in GitHub Desktop.
Save ntakouris/87765bc7aff78d9005607cff714ed178 to your computer and use it in GitHub Desktop.
import tensorflow as tf
# beam imports
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
import tensorflow_transform.beam as tft_beam
# orchestration
from tfx.orchestration import pipeline, metadata
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
# pipeline components
from tfx.components import CsvExampleGen
from tfx.components import SchemaGen
from tfx.components import StatisticsGen
from tfx.utils.dsl_utils import external_input
# config protos
from tfx.proto import example_gen_pb2
# --- Deployment configuration -------------------------------------------
# The empty strings below are blank placeholders in this file; presumably
# they must be filled in before the pipeline can be submitted.
REGION = ''
# Root location for pipeline artifacts; the bucket name is the blank part.
PIPELINE_ROOT = 'gs://' + ''
# Sub-paths under PIPELINE_ROOT used for the staging and temp locations.
STAGING = 'staging'
TEMP = 'temp'
PROJECT_ID = ''
JOB_NAME = ''
# File pattern handed to the example-gen input split.
DATASET_PATTERN = 'taxi_dataset.csv'

# Command-line style flags forwarded to Beam as `beam_pipeline_args`.
DATAFLOW_BEAM_PIPELINE_ARGS = [
    f'--project={PROJECT_ID}',
    '--runner=DataflowRunner',
    f'--temp_location={PIPELINE_ROOT}/{TEMP}',
    f'--staging_location={PIPELINE_ROOT}/{STAGING}',
    f'--region={REGION}',
    '--experiments=shuffle_mode=service',
    # Beam registers this option with an underscore (`--job_name`); the
    # hyphenated spelling `--job-name` is not a recognized flag.
    f'--job_name={JOB_NAME}',
]
def create_pipeline():
    """Build the TFX pipeline: CSV example gen -> statistics gen -> schema gen.

    Returns:
        The `pipeline.Pipeline` object constructed from the three components,
        the module-level PIPELINE_ROOT / DATAFLOW_BEAM_PIPELINE_ARGS, and a
        MySQL metadata connection config.
    """
    # Only a single 'train' split is declared; no eval split is configured.
    train_split = example_gen_pb2.Input.Split(
        name='train', pattern=DATASET_PATTERN)
    no_eval_config = example_gen_pb2.Input(splits=[train_split])

    # The input base is PIPELINE_ROOT; the split pattern above is presumably
    # resolved relative to it -- confirm against the ExampleGen docs.
    csv_source = external_input(PIPELINE_ROOT)
    example_gen = CsvExampleGen(input=csv_source, input_config=no_eval_config)

    # Each downstream component consumes the previous component's output.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'])

    stages = [example_gen, statistics_gen, schema_gen]
    display_name = f'Pipeline {JOB_NAME}'

    # NOTE(review): the connection values look like placeholders -- confirm
    # real settings are supplied before running.
    mysql_config = metadata.mysql_metadata_connection_config(
        host="ip", database="db", port=3306, username='usr', password='pwd')

    return pipeline.Pipeline(
        pipeline_name=display_name,
        pipeline_root=PIPELINE_ROOT,
        components=stages,
        beam_pipeline_args=DATAFLOW_BEAM_PIPELINE_ARGS,
        metadata_connection_config=mysql_config,
    )
if __name__ == '__main__':
    # Script entry point: construct the Beam DAG runner first, then build the
    # pipeline and hand it to the runner.
    dag_runner = BeamDagRunner()
    dag_runner.run(create_pipeline())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment