Last active
July 29, 2020 13:26
-
-
Save ntakouris/e8b23c1b5cd5222b69fbf5177081622e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example script: declares an end-to-end ML pipeline with fluent_tfx and runs
# it with the Beam DAG runner. All paths are resolved relative to this file.
import os

import fluent_tfx as ftfx
import tensorflow_model_analysis as tfma
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
from tfx.proto import trainer_pb2

# directory that contains this script; every other path below hangs off it
current_dir = os.path.dirname(os.path.realpath(__file__))

# get the user code file
user_code_file = os.path.join(current_dir, 'model_code.py')
print(
    f'Using {user_code_file} for preprocessing, training and tuning functions')

# specify where to output data artifacts, for example, tfrecord files and the model output
bucket_uri = os.path.join(current_dir, 'bucket')

# model evaluation criteria with tensorflow model analysis
# TODO: fill in model_specs / slicing_specs / metrics_specs for your model.
# The original snippet left a `<fill in please>` placeholder here, which is not
# valid Python; the empty config below only keeps the script parseable.
eval_config = tfma.EvalConfig()

# you can move all the constants here to a different `constants.py` file, for better organisation
# The fluent chain is wrapped in parentheses so no backslash continuations are needed.
pipeline = (
    ftfx.PipelineDef(name='simple_e2e', bucket=bucket_uri)
    .with_sqlite_ml_metadata()
    .from_csv(os.path.join(current_dir, 'data/'))
    .generate_statistics()
    .infer_schema(infer_feature_shape=True)
    .preprocess(user_code_file)
    .tune(user_code_file,
          train_args=trainer_pb2.TrainArgs(num_steps=5),
          eval_args=trainer_pb2.EvalArgs(num_steps=3))
    .train(user_code_file,
           train_args=trainer_pb2.TrainArgs(num_steps=10),
           eval_args=trainer_pb2.EvalArgs(num_steps=5))
    .evaluate_model(eval_config=eval_config,
                    example_provider_component=ftfx.input_builders.from_csv(
                        os.path.join(current_dir, 'data'),
                        name='eval_example_gen'))
    .push_to(relative_push_uri='serving')
    .bulk_infer(example_provider_component=ftfx.input_builders.from_csv(
        uri=os.path.join(current_dir, 'to_infer'),
        name='bulk_infer_example_gen'
    ))
)

# you can also view all the generated components along with their names as keys:
print('Exposed pipeline components dict:')
print(pipeline.components)

# build the pipeline definition and hand it to the Beam runner for execution
BeamDagRunner().run(pipeline.build())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment