// Gist kerinin/fe4aba38262efc9aebcee3a61a82a49c (last active April 18, 2019 17:45)
// Source data, generated by some preprocessing pipeline or read from Kafka
var examples: DataStream[(Input,Target)] = null
// Algorithms to train against. An algorithm defines most of the values passed to SageMaker
// when creating training jobs and models. Algorithms can be defined statically or read from
// a live stream (i.e. Kafka). Algorithms have an associated "id" that can be used to train
// multiple algorithms against a single dataset.
var algorithms: DataStream[AlgorithmEvent] = null
// `split_examples` uses a `Splitter` to partition an input stream into training and test datasets.
var (training, testing) = examples.split_examples(new DefaultSplitter[(Input,Target)]())
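// `DefaultSplitter` is not defined in this gist; as a hedged sketch, a `Splitter` could be
// a deterministic predicate routing each example to the training or testing side. The trait
// and the modulo-hash holdout policy below are illustrative assumptions, not the actual API.
trait ExampleSplitter[T] extends Serializable {
  def isTraining(example: T): Boolean
}

// Hypothetical default: hash each example and hold out roughly 1 in `holdout` for testing.
// A deterministic hash keeps the split stable across restarts, unlike a random draw.
class HashSplitter[T](holdout: Int = 10) extends ExampleSplitter[T] {
  def isTraining(example: T): Boolean = math.floorMod(example.hashCode, holdout) != 0
}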
// Training examples are accumulated into `TrainingBatch` records.
// Training batches describe a set of SageMaker training job channels and make no assumptions
// about the details of batching, aggregation, sampling, etc. Batches have an associated "key"
// that can be used to train multiple datasets against a single model.
val training_batches: DataStream[TrainingBatch[KEY]] = training
  .process(new CustomTrainingAggregation())
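// `CustomTrainingAggregation` is user-supplied; one minimal policy (an assumption, not the
// gist's implementation) is count-based batching: buffer examples per key and emit a batch
// every `batchSize` elements. The `SketchBatch` shape below is hypothetical.
case class SketchBatch[K, E](key: K, examples: Vector[E])

class CountBatcher[K, E](key: K, batchSize: Int) {
  private var buffer = Vector.empty[E]

  // Returns a completed batch once `batchSize` examples have accumulated, else None.
  def add(example: E): Option[SketchBatch[K, E]] = {
    buffer = buffer :+ example
    if (buffer.size >= batchSize) {
      val out = SketchBatch(key, buffer)
      buffer = Vector.empty
      Some(out)
    } else None
  }
}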
// Trains each batch against each algorithm.
// The most recent version of each algorithm is trained against the most recent training batch
// for each batch key. When new batches are received, they are trained against the most recent
// version of each algorithm. When new algorithms are received, they are trained against the
// most recent batch for each batch key.
val sagemaker_models: DataStream[SagemakerModel[KEY]] = training_batches.train(
  algorithm_events = algorithms,
  typeinfo = TypeInformation.of(new TypeHint[TrainingBatch[KEY]]() {}) // unfortunate boilerplate to handle generic keys
)
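// The retraining semantics described above (retrain whenever either side updates) can be
// modeled as two-sided keyed state: remember the latest element from each stream and emit a
// pairing on every update. This plain-Scala model is an assumed sketch of the `train`
// operator's behavior, not its implementation.
class LatestPairing[B, A, O](pair: (B, A) => O) {
  private var latestBatch: Option[B] = None
  private var latestAlgorithm: Option[A] = None

  // A new batch is trained against the most recent algorithm, if one has arrived.
  def onBatch(b: B): Option[O] = { latestBatch = Some(b); latestAlgorithm.map(pair(b, _)) }

  // A new algorithm is trained against the most recent batch, if one has arrived.
  def onAlgorithm(a: A): Option[O] = { latestAlgorithm = Some(a); latestBatch.map(pair(_, a)) }
}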
// Testing examples are accumulated into `TestingBatch` records.
// Testing batches describe the inputs to a SageMaker transform job and a "target" value.
// As with training batches, no assumptions are made about batching, aggregation, sampling, etc.
// Testing batches use the same keying mechanism as training batches.
val testing_batches: DataStream[TestingBatch[KEY,TARGET]] = testing
  .process(new CustomTestingAggregation())
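// The `TestingBatch` fields are not shown; as a sketch, a batch needs both the transform-job
// inputs and the targets to score against, aligned by position. This case class and its
// invariant are assumptions for illustration only.
case class SketchTestingBatch[K, I, T](key: K, inputs: Vector[I], targets: Vector[T]) {
  require(inputs.size == targets.size, "each input needs a matching target")
}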
// Evaluates the performance of each testing batch / algorithm pair.
// `validate` uses a `Validator` to evaluate the results of a SageMaker transform job and produce
// a performance element. The performance definition is generic and must be provided. The
// `Validator` interface is initialized with a description of the training job and its target
// values (as defined in the testing batch), then passed each SageMaker transform result file
// before producing the performance value.
val performance: DataStream[PERF] = testing_batches.connect(sagemaker_models)
  .validate(
    validator = new CustomValidator(),
    typeinfo = TypeInformation.of(new TypeHint[TARGET]() {}) // unfortunate boilerplate to handle generic targets
  )
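// A `Validator` per the description above: initialized with the batch's targets, fed each
// transform result file, then asked for a performance value. The accuracy metric and the
// Seq-based signature are illustrative assumptions; `PERF` here is simply a Double.
class AccuracyValidator[T](targets: Seq[T]) {
  private var correct = 0
  private var seen = 0

  // Consume one transform result file's worth of predictions, in target order.
  def addResults(predictions: Seq[T]): Unit =
    predictions.zip(targets.drop(seen)).foreach { case (p, t) =>
      if (p == t) correct += 1
      seen += 1
    }

  // Fraction of predictions matching their targets across all files seen so far.
  def performance: Double = if (seen == 0) 0.0 else correct.toDouble / seen
}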