Last active: March 1, 2018 21:16
-
-
Save soaxelbrooke/318b9cd7d05c4f1756a1ae0191de4401 to your computer and use it in GitHub Desktop.
Simple SQLite experiment logger
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Class for tracking experiments in a local sqlite database """ | |
class SqliteExperiment:
    """Record one experiment's hyperparameters and metrics in a SQLite file.

    Two tables are created on first use (existing tables are left as they
    are -- a changed spec is not migrated):

    * ``hparams`` -- one row per experiment, keyed by ``experiment_id``,
      with one column per entry of the ``hparams`` spec.
    * ``metrics`` -- one row per logged measurement: ``experiment_id``,
      ``measured_at`` (``time.time()`` when the row was written) and one
      column per entry of the ``metrics`` spec.

    A spec is a sequence of ``(column_name, python_type)`` pairs whose type
    is one of ``int``, ``float``, ``str`` or ``bool``.

    The connection stays open until :meth:`close` is called; instances can
    also be used as context managers, which close the connection on exit.
    """

    def __init__(self, hparams, metrics, experiment_id=None,
                 db_path='experiments.sqlite', log_every=None):
        """Open (creating if needed) the database and its tables.

        Args:
            hparams: spec of hyperparameter columns, ``[(name, type), ...]``.
            metrics: spec of metric columns, ``[(name, type), ...]``.
            experiment_id: identifier for this run; a random UUID4 string is
                generated when it is falsy.
            db_path: SQLite database file to open (``':memory:'`` works too).
            log_every: minimum step gap between throttled metric logs. When
                None, it is read from the ``LOG_EVERY`` environment variable,
                defaulting to 10000.
        """
        self.experiment_id = experiment_id or str(uuid4())
        # Materialize the specs: they are iterated again on every log call,
        # which would silently yield nothing for an exhausted generator.
        self.hparams = list(hparams)
        self.metrics = list(metrics)
        self.metric_names = (['experiment_id', 'measured_at']
                             + [name for name, _type in self.metrics])
        if log_every is None:
            log_every = os.environ.get('LOG_EVERY', 10000)
        self.log_every = int(log_every)
        self.last_log = None
        self.last_epoch = None
        self.db = sqlite3.connect(db_path)
        self.ensure_tables()

    def ensure_tables(self):
        """Create the ``hparams`` and ``metrics`` tables if they don't exist."""
        # Build the full column list before joining so an empty spec does not
        # leave a dangling comma (which is a SQL syntax error).
        hparam_cols = (['experiment_id text primary key']
                       + self._column_defs(self.hparams))
        metric_cols = (['experiment_id text', 'measured_at int']
                       + self._column_defs(self.metrics))
        self.db.execute('CREATE TABLE IF NOT EXISTS hparams (\n{}\n)'.format(
            ',\n'.join(hparam_cols)))
        self.db.execute('CREATE TABLE IF NOT EXISTS metrics (\n{}\n)'.format(
            ',\n'.join(metric_cols)))
        self.db.commit()

    @classmethod
    def to_sqlite_col_type(cls, col_type):
        """Map a Python type to its SQLite column type.

        Raises:
            KeyError: if ``col_type`` is not int, float, str or bool.
        """
        return {
            int: 'integer',
            float: 'real',
            str: 'text',
            bool: 'integer',
        }[col_type]

    @staticmethod
    def _quote_ident(name):
        """Quote ``name`` as a SQL identifier (so keywords like ``order`` work)."""
        return '"{}"'.format(str(name).replace('"', '""'))

    def _column_defs(self, spec):
        """Return one ``"name" type`` column definition per spec entry."""
        return ['{} {}'.format(self._quote_ident(col_name),
                               self.to_sqlite_col_type(col_type))
                for col_name, col_type in spec]

    def to_sql_column_defs(self, spec):
        """Render ``spec`` as newline-separated column definitions.

        Each definition, including the last, is followed by a comma.
        """
        return ',\n'.join(self._column_defs(spec)) + ','

    def log_hparams(self, hparams):
        """Insert this experiment's row into the ``hparams`` table.

        Args:
            hparams: mapping from every hyperparameter name in the spec to its
                value. List and tuple values are stored as comma-joined
                strings of their ``str()`` items.

        Raises:
            KeyError: if a spec'd hyperparameter is missing from ``hparams``.
            sqlite3.IntegrityError: if this experiment_id already has a row
                (``experiment_id`` is the table's primary key).
        """
        names = ['experiment_id'] + [name for name, _type in self.hparams]
        values = [self.experiment_id] + [hparams[name] for name, _type in self.hparams]
        values = [','.join(map(str, value)) if isinstance(value, (list, tuple)) else value
                  for value in values]
        # Name the columns explicitly (as log_metrics does) instead of relying
        # on the table's positional column order.
        self.db.execute('INSERT INTO hparams ({}) VALUES ({})'.format(
            ', '.join(map(self._quote_ident, names)),
            ', '.join(['?'] * len(values))), values)
        self.db.commit()

    def should_log(self, epoch, step):
        """Decide whether a throttled metric log is due, recording it if so.

        A log is due on the first call, whenever ``epoch`` is greater than the
        last logged epoch, or once ``step`` is at least ``log_every`` past the
        last logged step. When due, ``epoch``/``step`` become the new last log
        point.

        Returns:
            bool: True if the caller should log now.
        """
        due = (self.last_epoch is None
               or self.last_log is None
               or self.last_epoch < epoch
               or step - self.last_log >= self.log_every)
        if due:
            self.last_epoch = epoch
            self.last_log = step
        return due

    def log_metrics(self, epoch, step, metrics, force=False):
        """Append one row to the ``metrics`` table unless throttled.

        Args:
            epoch: current epoch; used only for throttling (not stored).
            step: current step; used only for throttling (not stored).
            metrics: mapping from metric column name to value. Spec'd metrics
                missing from the mapping are stored as NULL.
            force: write the row regardless of throttling. A forced write does
                not update the throttling state.
        """
        if not force and not self.should_log(epoch, step):
            return
        values = ([self.experiment_id, time.time()]
                  + [metrics.get(name) for name, _type in self.metrics])
        self.db.execute('INSERT INTO metrics ({}) VALUES ({})'.format(
            ', '.join(map(self._quote_ident, self.metric_names)),
            ', '.join(['?'] * len(values))), values)
        self.db.commit()

    def close(self):
        """Close the underlying SQLite connection."""
        self.db.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
# Example usage:
################
# Guarded so that importing this module neither creates experiments.sqlite
# nor runs training. `my_model`, `train_x` and `train_y` are not defined in
# this file -- substitute your own model and data.
if __name__ == '__main__':
    sle = SqliteExperiment(
        [('vocab_size', int), ('msg_len', int), ('context_dim', int),
         ('embed_dim', int), ('batch_size', int)],
        [('loss', float), ('dev_loss', float), ('epoch', int),
         ('acc', float), ('dev_acc', float)],
        os.environ.get('EXPERIMENT_ID'))
    sle.log_hparams({'vocab_size': 16384, 'msg_len': 100, 'context_dim': 100,
                     'embed_dim': 200, 'batch_size': 512})

    def epoch_callback(loss, dev_loss, epoch, acc, dev_acc):
        # 'epoch' is a declared metric column, so it must be in the dict:
        # log_metrics' own epoch argument only drives throttling and is not
        # written to the table (missing metrics are stored as NULL).
        sle.log_metrics(epoch, train_x.shape[0],
                        {'loss': loss, 'dev_loss': dev_loss, 'epoch': epoch,
                         'acc': acc, 'dev_acc': dev_acc})

    try:
        my_model.train(train_x, train_y, callback=epoch_callback)
    finally:
        sle.db.close()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment