This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from smdebug.rules import invoke_rule
from smdebug.trials import create_trial

# Replay a SageMaker Debugger rule locally against the tensors a training job
# saved under ./smd_outputs/<JOB_NAME> (<JOB_NAME> is a placeholder to fill in).
# The path must use plain ASCII quotes: a typographic quote (’) is not a Python
# string delimiter and leaves the literal unterminated.
trial = create_trial(path='./smd_outputs/<JOB_NAME>')

# NOTE(review): CustomVanishingGradientRule is not defined in this excerpt;
# presumably a smdebug Rule subclass defined in an earlier cell — confirm.
rule_obj = CustomVanishingGradientRule(trial, threshold=0.0001)

# start_step=0 starts from the first step; end_step=None presumably means
# "through the last available step" — confirm against the smdebug docs.
invoke_rule(rule_obj, start_step=0, end_step=None)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Keras checkpointing for resumable training: each checkpoint file is named
# with the epoch number so a later run can find the most recent one.
# NOTE(review): "/opt/ml/checkpoints" is presumably SageMaker's default local
# checkpoint directory, which it syncs with the estimator's checkpoint_s3_uri
# — confirm against the SageMaker checkpointing docs.
| from tensorflow.keras.callbacks import ModelCheckpoint | |
| checkpoint_path = "/opt/ml/checkpoints" | |
# '{epoch:03d}' is filled in by Keras and zero-pads the epoch to three digits,
# e.g. cifar10-<model_type>.005.h5. model_type is assumed to be defined earlier.
| checkpoint_names = 'cifar10-'+model_type+'.{epoch:03d}.h5' | |
# save_weights_only=False stores the full model (architecture + weights +
# optimizer state). NOTE(review): monitor='val_loss' presumably only matters if
# save_best_only is enabled; with the default every epoch is saved — confirm.
| checkpoint_callback = ModelCheckpoint(filepath=f'{checkpoint_path}/{checkpoint_names}', | |
| save_weights_only=False, | |
| monitor='val_loss') | |
# NOTE(review): this excerpt is cut off mid-call and the `...` stands for
# elided arguments; checkpoint_callback is presumably passed as
# callbacks=[checkpoint_callback] in the missing lines — confirm.
# initial_epoch makes fit() continue the epoch count from a resumed checkpoint
# (epoch_number is assumed to come from the checkpoint loader).
| model.fit(train_dataset, ... | |
| epochs=epochs, | |
| initial_epoch=epoch_number, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import re
from tensorflow.keras.models import load_model


def load_checkpoint_model(checkpoint_path):
    """Select the latest epoch-numbered Keras ``.h5`` checkpoint in *checkpoint_path*.

    Checkpoint files are expected to be named ``<prefix>.<epoch>.h5`` (for
    example ``cifar10-resnet.005.h5``). Epoch numbers are compared as
    integers, so ``.1000.`` correctly ranks after ``.999.``. ``.h5`` files
    without an epoch number are ignored.

    Raises:
        FileNotFoundError: if *checkpoint_path* does not exist (from os.listdir).
        ValueError: if the directory holds no epoch-numbered ``.h5`` checkpoint.

    NOTE(review): the published excerpt stops after choosing the file; the
    remainder of the function (presumably loading it with ``load_model`` and
    returning the result) is not shown here.
    """
    # Digits between the last two dots, immediately before the ".h5" suffix.
    epoch_pattern = re.compile(r'\.(\d+)\.h5$')

    files = []
    epoch_numbers = []
    for name in os.listdir(checkpoint_path):
        match = epoch_pattern.search(name)
        if match is None:
            continue  # not an epoch-numbered checkpoint
        files.append(name)
        epoch_numbers.append(int(match.group(1)))

    if not files:
        raise ValueError(
            f'No epoch-numbered .h5 checkpoints found in {checkpoint_path!r}')

    max_epoch_number = max(epoch_numbers)
    max_epoch_index = epoch_numbers.index(max_epoch_number)
    max_epoch_filename = files[max_epoch_index]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Configure a SageMaker TensorFlow estimator for the CIFAR-10 training script.
# NOTE(review): sagemaker_session and role are assumed to be defined in an
# earlier cell.
| from sagemaker.tensorflow import TensorFlow | |
| bucket_name = sagemaker_session.default_bucket() | |
| output_path = f's3://{bucket_name}/jobs' | |
| job_name = 'tensorflow-spot-job' | |
# NOTE(review): train_instance_count / train_instance_type are SageMaker Python
# SDK v1 names (renamed instance_count / instance_type in SDK v2) — confirm the
# SDK version in use. This excerpt is cut off before the closing parenthesis;
# output_path and job_name are presumably consumed in the missing lines.
| tf_estimator = TensorFlow(entry_point = 'cifar10-training-sagemaker.py', | |
| role = role, | |
| train_instance_count = 1, | |
| train_instance_type = 'ml.p3.2xlarge', | |
| framework_version = '1.15', |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Spot-instance variant of the estimator above, with checkpoints synced to S3
# so an interrupted job can resume from its last saved epoch.
#
# checkpoint_s3_uri must be a concrete S3 location chosen up front. Reading
# tf_estimator.checkpoint_s3_uri while constructing tf_estimator either raises
# NameError or reuses a previous estimator's value; it never supplies a location.
# NOTE(review): this prefix layout is an assumption — adjust to your bucket
# conventions. output_path and job_name come from the estimator cell above.
checkpoint_s3_uri = f'{output_path}/{job_name}/checkpoints'

tf_estimator = TensorFlow(entry_point='cifar10-training-sagemaker.py',
                          # ... remaining constructor arguments as in the cell above ...
                          checkpoint_s3_uri=checkpoint_s3_uri,
                          train_use_spot_instances=True,
                          # NOTE(review): SageMaker requires max_wait >= max_run for
                          # Spot jobs, so set train_max_run <= 7200 among the elided
                          # arguments. These are SDK v1 names; SDK v2 uses
                          # use_spot_instances / max_wait / max_run.
                          train_max_wait=7200)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from smexperiments.experiment import Experiment

# Low-level SageMaker client shared by the Experiments objects below.
# NOTE(review): boto3 is assumed to be imported in an earlier cell.
sm = boto3.Session().client('sagemaker')

# One experiment groups every training trial run from this notebook. The name
# is a plain string literal (it has no placeholders, so no f-string is needed).
# NOTE(review): creating an experiment whose name already exists will likely
# fail on a re-run — confirm and reuse/load the existing one if so.
training_experiment = Experiment.create(
    experiment_name="cifar10-training-experiment",
    description="Hypothesis: If I use my custom image classification model, it will deliver better accuracy compared to a ResNet50 model on the CIFAR10 dataset",
    sagemaker_boto_client=sm,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Search space: every combination of these options becomes one trial.
hyperparam_options = dict(
    optimizer=['adam', 'sgd', 'rmsprop'],
    model=['resnet', 'custom'],
    epochs=[30, 60, 120],
)

# Parallel tuples: option names and their candidate-value lists.
hypnames = tuple(hyperparam_options.keys())
hypvalues = tuple(hyperparam_options.values())

# Cartesian product of the value lists, each combination re-keyed by name
# (3 optimizers x 2 models x 3 epoch counts = 18 trials).
trial_hyperparameter_set = [
    dict(zip(hypnames, combo))
    for combo in itertools.product(*hypvalues)
]
trial_hyperparameter_set
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Hyperparameters shared by every trial; each trial's own options are merged
# on top of these when its job is configured.
static_hyperparams = dict([
    ('batch-size', 128),
    ('learning-rate', 0.001),
    ('weight-decay', 1e-6),
    ('momentum', 0.9),
])
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Record experiment-wide metadata: the dataset location, both the fixed
# hyperparameters and the search space, and the data-preparation script.
# NOTE(review): Tracker is assumed to be imported earlier (presumably from
# smexperiments.tracker), and `datasets` to hold the dataset's S3 URI.
with Tracker.create(
    display_name="experiment-metadata",
    sagemaker_boto_client=sm,
    artifact_bucket=bucket_name,
    artifact_prefix=training_experiment.experiment_name,
) as exp_tracker:
    exp_tracker.log_input(
        name="cifar10-dataset",
        value=datasets,
        media_type="s3/uri",
    )
    exp_tracker.log_parameters(static_hyperparams)
    # NOTE(review): these values are lists; confirm log_parameters accepts
    # non-scalar values rather than dropping or rejecting them.
    exp_tracker.log_parameters(hyperparam_options)
    exp_tracker.log_artifact(file_path='generate_cifar10_tfrecords.py')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| for trial_hyp in trial_hyperparameter_set: | |
| # Combine static hyperparameters and trial specific hyperparameters | |
| hyperparams = {**static_hyperparams, **trial_hyp} | |
| # Create unique job name with hyperparameter and time | |
| time_append = int(time.time()) | |
| hyp_append = "-".join([str(elm) for elm in trial_hyp.values()]) | |
| job_name = f'cifar10-training-{hyp_append}-{time_append}' | |
| # Create a Tracker to track Trial specific hyperparameters |