Skip to content

Instantly share code, notes, and snippets.

@marctemp
Last active March 5, 2021 16:17
Show Gist options
  • Select an option

  • Save marctemp/06bed4a4d243930a1db6bccbd02d8346 to your computer and use it in GitHub Desktop.

Select an option

Save marctemp/06bed4a4d243930a1db6bccbd02d8346 to your computer and use it in GitHub Desktop.
Applied Innovation BERT Fine-Tuning Project
from sagemaker.estimator import Estimator
# Describe the SageMaker training job: a single ml.m5.2xlarge instance using
# the image at the given ECR URI, with model_train.py (from source_dir) as the
# entry point. The '<...>' placeholders must be filled in before running.
estimator = Estimator(
    image_uri='<ECR REPO URL>:latest',
    role='<SAGEMAKER ROLE ARN>',
    base_job_name='bert-training-job',
    instance_count=1,
    instance_type='ml.m5.2xlarge',
    source_dir='/var/bert/model_pkg/',
    entry_point='model_train.py',
)

# Start the training job; no input data channels are passed.
estimator.fit()
from json import dumps
from boto3 import client as boto_client
from utils import time_decorator
# Shared S3 client, created once when this module is imported; it is the
# default `s3_client` argument of the save_* helpers below.
S3_CLIENT = boto_client('s3')
# Default destination bucket for every artifact the save_* helpers upload.
BUCKET = 'sagemaker-transformer-transfer-learning'
# Default object names; each helper uploads to "<key>/<file name>".
MODEL_CONFIG_FILE = 'config.json'
MODEL_STATE_DICT_FILE = 'state_dict.json'
MODEL_VOCAB_FILE = 'vocab.txt'
MODEL_METRICS_FILE = 'metrics.json'
@time_decorator
def save_model_to_s3(key, model, s3_client=S3_CLIENT, bucket=BUCKET, model_file=MODEL_STATE_DICT_FILE):
    """Upload a model's state dictionary to S3 as a JSON document.

    The upload is best-effort: any failure is reported on stdout (including
    the exception) and is not re-raised, so callers are never interrupted.

    Args:
        key: Key prefix; the object is written to ``<key>/<model_file>``.
        model: Object exposing ``state_dict()``; every value must support
            ``.cpu().tolist()`` (presumably torch tensors -- confirm).
        s3_client: Client exposing ``put_object``.
        bucket: Destination bucket name.
        model_file: Object name appended to ``key``.
    """
    try:
        print(f'Saving {model} to S3: {bucket}/{key}...')
        # JSON cannot encode the raw values, so move each one to CPU and
        # convert it to (nested) Python lists first.
        state_dict = {
            name: tensor.cpu().tolist()
            for name, tensor in model.state_dict().items()
        }
        print('Saving state dictionary -- this may take 30 mins (over 2GB)...')
        s3_client.put_object(
            Body=dumps(state_dict).encode('utf-8'),
            Bucket=bucket,
            Key=f'{key}/{model_file}',
        )
        print(f'State dictionary saved to S3: {bucket}/{key}')
    except Exception as exc:
        # `Exception` (not a bare except) so Ctrl-C / SystemExit still
        # propagate; the cause is printed instead of being discarded.
        print(f'Issue running {save_model_to_s3.__name__}: {exc!r}')
@time_decorator
def save_config_to_s3(key, config, s3_client=S3_CLIENT, bucket=BUCKET, config_file=MODEL_CONFIG_FILE):
    """Upload a model configuration to S3 as a JSON document.

    The upload is best-effort: any failure is reported on stdout (including
    the exception) and is not re-raised.

    Args:
        key: Key prefix; the object is written to ``<key>/<config_file>``.
        config: JSON-serialisable configuration (passed to ``json.dumps``).
        s3_client: Client exposing ``put_object``.
        bucket: Destination bucket name.
        config_file: Object name appended to ``key``.
    """
    try:
        print(f'Saving {config} to S3: {bucket}/{key}...')
        s3_client.put_object(
            Body=dumps(config).encode('utf-8'),
            Bucket=bucket,
            Key=f'{key}/{config_file}',
        )
        print(f'Config saved to S3: {bucket}/{key}')
    except Exception as exc:
        # `Exception` (not a bare except) so Ctrl-C / SystemExit still
        # propagate; the cause is printed instead of being discarded.
        print(f'Issue running {save_config_to_s3.__name__}: {exc!r}')
@time_decorator
def save_vocab_to_s3(key, vocab, s3_client=S3_CLIENT, bucket=BUCKET, vocab_file=MODEL_VOCAB_FILE):
    """Upload a vocabulary to S3 as UTF-8 text.

    The upload is best-effort: any failure is reported on stdout (including
    the exception) and is not re-raised.

    Args:
        key: Key prefix; the object is written to ``<key>/<vocab_file>``.
        vocab: Iterable of strings. The items are concatenated exactly as
            given -- no separator is inserted.
        s3_client: Client exposing ``put_object``.
        bucket: Destination bucket name.
        vocab_file: Object name appended to ``key``.
    """
    try:
        print(f'Saving {vocab} to S3: {bucket}/{key}...')
        # Single join instead of repeated `bytes +=`, which copies the whole
        # buffer on every iteration.
        # NOTE(review): items are joined without a separator; presumably each
        # entry already ends with '\n' -- confirm against the caller.
        byte_vocab = b''.join(token.encode('utf-8') for token in vocab)
        s3_client.put_object(
            Body=byte_vocab,
            Bucket=bucket,
            Key=f'{key}/{vocab_file}',
        )
        print(f'Vocab saved to S3: {bucket}/{key}')
    except Exception as exc:
        # `Exception` (not a bare except) so Ctrl-C / SystemExit still
        # propagate; the cause is printed instead of being discarded.
        print(f'Issue running {save_vocab_to_s3.__name__}: {exc!r}')
@time_decorator
def save_metrics_to_s3(key, metrics, s3_client=S3_CLIENT, bucket=BUCKET, metrics_file=MODEL_METRICS_FILE):
    """Upload evaluation metrics to S3 as a JSON document.

    The upload is best-effort: any failure is reported on stdout (including
    the exception) and is not re-raised.

    Args:
        key: Key prefix; the object is written to ``<key>/<metrics_file>``.
        metrics: JSON-serialisable metrics (passed to ``json.dumps``).
        s3_client: Client exposing ``put_object``.
        bucket: Destination bucket name.
        metrics_file: Object name appended to ``key``.
    """
    try:
        s3_client.put_object(
            Body=dumps(metrics).encode('utf-8'),
            Bucket=bucket,
            Key=f'{key}/{metrics_file}',
        )
        print(f'Metrics saved to S3: {bucket}/{key}')
    except Exception as exc:
        # `Exception` (not a bare except) so Ctrl-C / SystemExit still
        # propagate; the cause is printed instead of being discarded.
        print(f'Issue running {save_metrics_to_s3.__name__}: {exc!r}')
@time_decorator
def save_to_s3(key, model, config, vocab, metrics, s3_client=S3_CLIENT, bucket=BUCKET):
    """Upload config, vocab, model state and metrics under one S3 key prefix.

    The four uploads run in a fixed order (config, vocab, model, metrics) and
    are independent: a failure in one is reported on stdout and the remaining
    uploads are still attempted.

    Args:
        key: Key prefix shared by all uploaded objects.
        model: Passed to ``save_model_to_s3``.
        config: Passed to ``save_config_to_s3``.
        vocab: Passed to ``save_vocab_to_s3``.
        metrics: Passed to ``save_metrics_to_s3``.
        s3_client: Client forwarded to every helper.
        bucket: Bucket name forwarded to every helper.
    """
    uploads = (
        (save_config_to_s3, config),
        (save_vocab_to_s3, vocab),
        (save_model_to_s3, model),
        (save_metrics_to_s3, metrics),
    )
    for upload, payload in uploads:
        try:
            # Forward the caller's client and bucket; previously they were
            # accepted but ignored, so the helpers always used the defaults.
            upload(key, payload, s3_client, bucket)
        except Exception as exc:
            # Keep going with the remaining artifacts, but say what failed.
            print(f'Issue running {upload.__name__}: {exc!r}')
# Fragment from inside a method (its header is outside this snippet).
# If `vocab_file` is not the path of an existing file, treat the argument
# itself as the vocabulary text: split it on '\n' and map each token to its
# position in that list. A token that occurs more than once keeps the index
# of its last occurrence.
if not os.path.isfile(vocab_file):
print('Trying to read from string')
vocab = collections.OrderedDict()
tokens = vocab_file.split('\n')
for index, token in enumerate(tokens):
vocab[token] = index
self.vocab = vocab
else:
# Otherwise let load_vocab (defined outside this snippet) build the vocab
# from the file at that path.
self.vocab = load_vocab(vocab_file)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment