This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.models import save_mxnet_model

# Model artifacts go under /opt/ml/model, the directory SageMaker collects
# as the training job's model output (per the SageMaker container contract).
prefix = '/opt/ml/'
model_path = os.path.join(prefix, 'model')
# Tag the artifact names with the number of training epochs.
model_name = 'mnist-cnn-' + str(epochs)
# NOTE(review): if `model` was replaced by the multi_gpu_model wrapper, Keras
# recommends saving the original (template) model instead — confirm for the
# MXNet backend.
# Keras model
model.save(os.path.join(model_path, model_name + '.hd5'))
# MXNet model (exported under the same name prefix)
save_mxnet_model(model=model, prefix=os.path.join(model_path, model_name))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.utils import multi_gpu_model

model = Sequential()
model.add(...)  # placeholder: layer definitions elided in this excerpt
...
# Keep a handle on the un-wrapped model: Keras recommends saving/exporting the
# template model rather than the object returned by multi_gpu_model.
template_model = model
# Only replicate across GPUs when more than one is requested; otherwise the
# plain model is trained as-is.
if gpu_count > 1:
    model = multi_gpu_model(model, gpus=gpu_count)
# Compile after wrapping so the (possibly multi-GPU) model is the one compiled.
model.compile(...)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load_data(input_path):
    """Load Fashion-MNIST-style gzipped IDX files from ``input_path``.

    Adapted from
    https://github.com/keras-team/keras/blob/master/keras/datasets/fashion_mnist.py

    Args:
        input_path: Directory containing ``training/`` and ``validation/``
            sub-directories holding the four gzipped IDX files listed below.
            A trailing path separator is optional.

    Returns:
        ``(x_train, y_train), (x_test, y_test)`` where labels are 1-D
        ``uint8`` arrays and images are ``uint8`` arrays of shape
        ``(n, 28, 28)``.
    """
    import os

    files = ['training/train-labels-idx1-ubyte.gz',
             'training/train-images-idx3-ubyte.gz',
             'validation/t10k-labels-idx1-ubyte.gz',
             'validation/t10k-images-idx3-ubyte.gz']
    # os.path.join works whether or not input_path ends with a separator;
    # plain concatenation silently produced a wrong path when it did not.
    paths = [os.path.join(input_path, name) for name in files]

    # IDX label files start with an 8-byte header, image files with 16 bytes.
    # Load training labels
    with gzip.open(paths[0], 'rb') as lbpath:
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    # Load training images
    with gzip.open(paths[1], 'rb') as imgpath:
        x_train = np.frombuffer(
            imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
    # Load validation labels
    with gzip.open(paths[2], 'rb') as lbpath:
        y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    # Load validation images
    with gzip.open(paths[3], 'rb') as imgpath:
        x_test = np.frombuffer(
            imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)

    return (x_train, y_train), (x_test, y_test)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Read the training hyperparameters from the JSON file SageMaker places in the
# container. Values are cast explicitly because the file's values are expected
# to be strings (the defaults below are strings for the same reason).
prefix = '/opt/ml/'
param_path = os.path.join(prefix, 'input/config/hyperparameters.json')
with open(param_path, 'r') as params:
    hyperParams = json.load(params)
lr = float(hyperParams.get('lr', '0.1'))
batch_size = int(hyperParams.get('batch_size', '128'))
epochs = int(hyperParams.get('epochs', '10'))
# 0 or 1 means no multi-GPU replication (the model is wrapped only when > 1).
gpu_count = int(hyperParams.get('gpu_count', '0'))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): the trailing `| |` on each line below is residue from the web
# page this excerpt was copied from; it must be removed before running.
# NOTE(review): the Estimator(...) call is truncated in this excerpt (no
# closing parenthesis) — the remaining arguments are not shown.
# S3 prefix handed to the Estimator as its output_path (job output location),
# under the session's default bucket.
output_path = 's3://{}/{}/output'.format(sess.default_bucket(), repo_name) | |
# ECR URI of the custom training image, i.e. the :latest tag pushed by the
# docker tag/push step.
image_name = '{}.dkr.ecr.{}.amazonaws.com/{}:latest'.format(account, region, repo_name) | |
# NOTE(review): image_name / train_instance_count / train_instance_type are
# SageMaker Python SDK v1 keyword names; SDK v2 renamed them (image_uri,
# instance_count, instance_type) — confirm the SDK version in use.
estimator = sagemaker.estimator.Estimator( | |
image_name=image_name, | |
base_job_name=base_job_name, | |
role=role, | |
train_instance_count=1, | |
train_instance_type=train_instance_type, | |
output_path=output_path, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Upload the local train/ and validation/ folders to S3 under
# <repo_name>/input/train and <repo_name>/input/validation. No bucket is
# given, so the session's default bucket is used; each call returns the
# S3 URI of the uploaded data.
local_directory = 'data'
prefix = repo_name + '/input'
train_input_path = sess.upload_data(
    local_directory + '/train/', key_prefix=prefix + '/train')
validation_input_path = sess.upload_data(
    local_directory + '/validation/', key_prefix=prefix + '/validation')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the training image locally, then tag and push it to the ECR repository
# as :latest. Expansions are quoted so values containing spaces or glob
# characters are not split by the shell.
image_uri="${account}.dkr.ecr.${region}.amazonaws.com/${repo_name}:latest"
docker build -t "$image_tag" -f "$dockerfile" .
docker tag "$image_tag" "$image_uri"
docker push "$image_uri"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create the ECR repository if it does not exist yet. --region is passed
# explicitly so the repository lives in the same region used for login/push.
if ! aws ecr describe-repositories --region "$region" --repository-names "$repo_name" > /dev/null 2>&1
then
    aws ecr create-repository --region "$region" --repository-name "$repo_name" > /dev/null
fi
# Authenticate Docker against ECR. `aws ecr get-login` (whose output used to
# be executed directly) was removed in AWS CLI v2; get-login-password pipes
# the token to docker login on stdin instead.
# Assumes $account is set, as it is for the docker tag/push step.
aws ecr get-login-password --region "$region" \
    | docker login --username AWS --password-stdin "${account}.dkr.ecr.${region}.amazonaws.com"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
FROM nvidia/cuda:9.0-runtime
# System libraries and Python 2/3 development toolchains needed by the
# training stack.
RUN apt-get update && \
    apt-get -y install build-essential libopencv-dev libopenblas-dev libjemalloc-dev libgfortran3 \
        python-dev python3-dev python3-pip wget curl
# Install the training script as an executable named `train`.
# NOTE(review): presumably /opt/program is put on PATH later in this
# Dockerfile (not shown) so the container can be started with `train`.
COPY mnist_cnn.py /opt/program/train
RUN chmod +x /opt/program/train
# Keras configuration directory (presumably receives keras.json selecting the
# MXNet backend); -p keeps the step from failing if it already exists.
RUN mkdir -p /root/.keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{
    "epsilon": 1e-07,
    "floatx": "float32",
    "image_data_format": "channels_first",
    "backend": "mxnet"
}