Skip to content

Instantly share code, notes, and snippets.

@bertsky
Last active February 25, 2024 16:23
Show Gist options
  • Save bertsky/9b9892bbbf4c7f84d5cd1375b53d8cf4 to your computer and use it in GitHub Desktop.
Save bertsky/9b9892bbbf4c7f84d5cd1375b53d8cf4 to your computer and use it in GitHub Desktop.
Dump the user metadata of kraken model files, extract training metrics from checkpoints, or fix missing validation metrics.
#!/usr/bin/env python3
# Dump the user metadata of kraken model files, extract training/validation metrics from checkpoints, or fix missing validation metrics.
import click
import json
import os
# Quiet TensorFlow's C++ logging (level '3') unless the user already chose a
# level via the environment. This has to happen before kraken is imported
# below, so the setting is in place if TensorFlow gets loaded on the way.
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '3')
from kraken.lib import vgsl
import numpy as np
# Verbosity requested on the command line (number of -v flags given).
is_verbose = 0


@click.group()
@click.option('-v', '--verbose', count=True)
@click.version_option(version='0.0')
def cli(verbose):
    """Inspect, plot or fix kraken model files."""
    # Without `global`, this assignment would only create a local variable
    # and the module-level `is_verbose` would stay 0 forever.
    global is_verbose
    is_verbose = verbose
@cli.command(help='Dump user metadata of Core ML file')
@click.argument('models', nargs=-1, type=click.Path(exists=True))
def dump(models):
    """Print the user metadata of each model as indented, key-sorted JSON.

    MODELS: one or more model files. One JSON document is printed per
    file, in the order the files were given.
    """
    for path in models:
        user_metadata = vgsl.TorchVGSLModel.load_model(path).user_metadata
        rendered = json.dumps(user_metadata, sort_keys=True, indent=4)
        print(rendered)
@cli.command(help='Extract training/validation metrics from model checkpoints')
@click.option('-p', '--plot', help="render course of metrics as an image file", type=click.Path())
@click.argument('models', nargs=-1, type=click.Path(exists=True))
def metrics(plot, models):
    """Print, and optionally plot, the training metrics of model checkpoints.

    MODELS: checkpoint files, presumably in epoch order (the plot uses each
    file's position as its epoch).

    For each checkpoint, the ``metrics`` entry of its user metadata is
    iterated as pairs whose second element is a dict (presumably
    ``(epoch, values)`` -- TODO confirm against kraken). One tab-separated
    line is printed: the sum of all ``train_loss`` values and the average
    of all ``val_accuracy`` values in that entry.

    With ``--plot FILE``, both series are also rendered into FILE.
    """
    train_losses = []
    val_accuracies = []
    print("train_loss\tval_accuracy")
    for filename in models:
        metadata = vgsl.TorchVGSLModel.load_model(filename).user_metadata
        entries = metadata['metrics']
        train_loss = np.sum([val['train_loss'] for _, val in entries])
        val_accuracy = np.average([val['val_accuracy'] for _, val in entries])
        print(f"{train_loss}\t{val_accuracy}")
        train_losses.append(train_loss)
        val_accuracies.append(val_accuracy)
    if plot:
        _plot_metrics(plot, models, train_losses, val_accuracies)


def _plot_metrics(plot, models, train_losses, val_accuracies):
    """Render log train-loss increments and validation accuracy per epoch into image file *plot*."""
    from matplotlib import pyplot as plt
    # train_loss seems to be accumulative, so plot per-checkpoint increments.
    # Increment i is the difference between checkpoints i and i+1, i.e. it
    # belongs to epoch i+1; plotting it at x=i would shift the loss curve one
    # epoch to the left of the accuracy curve.
    loss_steps = np.diff(train_losses)
    loss_epochs = np.arange(1, len(loss_steps) + 1)
    acc_epochs = np.arange(len(val_accuracies))

    fig, ax1 = plt.subplots()
    color = 'tab:blue'
    ax1.set_xlabel('epochs')
    ax1.set_ylabel('log loss', color=color)
    ax1.plot(loss_epochs, np.log(loss_steps), 'x-', color=color, label='train_loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.grid(axis='y')

    # Accuracy shares the epoch axis but gets its own y-axis on the right.
    ax2 = ax1.twinx()
    color = 'tab:red'
    ax2.set_ylabel('accuracy', color=color)
    ax2.plot(acc_epochs, val_accuracies, 'o-', color=color, label='val_accuracy')
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.grid(axis='y')

    fig.legend()
    plt.title(os.path.commonprefix(models))
    fig.savefig(plot)
    # pyplot keeps every open figure alive until it is closed explicitly.
    plt.close(fig)
@cli.command(help='Fix Core ML file')
@click.argument('model', nargs=-1, type=click.Path(exists=True))
def fix(model):
    """Add a missing ``val_metric`` to the metrics metadata of each model.

    MODEL: one or more model files.

    For every ``metrics`` entry lacking ``val_metric``:
    - its ``val_accuracy`` is copied into ``val_metric``;
    - it is also copied into the first ``accuracy`` entry with the same key
      whose value is negative.

    Models that needed a change are saved to ``fixed/<original path>``,
    keeping the original access/modification times. Other models are only
    reported.
    """
    for filename in model:
        m = vgsl.TorchVGSLModel.load_model(filename)
        metadata = m.user_metadata
        # Either key may be absent; treat that as "nothing to fix" instead of
        # crashing with a TypeError when iterating over None.
        accuracy = metadata.get('accuracy') or []
        # Named `history` (not `metrics`) to avoid shadowing the `metrics` command.
        history = metadata.get('metrics') or []
        needs_fix = False
        for entry in history:
            key, values = entry[0], entry[1]
            if 'val_metric' in values:
                continue
            needs_fix = True
            val_accuracy = values['val_accuracy']
            values['val_metric'] = val_accuracy
            # Overwrite the first negative (presumably placeholder) accuracy
            # recorded under the same key.
            for a in accuracy:
                if a[0] == key and a[1] < 0.0:
                    a[1] = val_accuracy
                    break
        if needs_fix:
            print(f'Fixing {filename}')
            target = f'fixed/{filename}'
            # Create fixed/ (and any sub-directories mirroring the input path);
            # presumably save_model does not create them itself -- confirm.
            os.makedirs(os.path.dirname(target), exist_ok=True)
            m.save_model(target)
            # Carry the original timestamps over to the fixed copy.
            statinfo = os.stat(filename)
            os.utime(target, ns=(statinfo.st_atime_ns, statinfo.st_mtime_ns))
        else:
            print(f'No fix {filename}')
if __name__ == '__main__':
    # Script entry point: dispatch to the click command group.
    cli()
@bertsky
Copy link
Author

bertsky commented Feb 25, 2024

The latest version can also extract the metrics from multiple checkpoints, e.g.:

mlmodel.py metrics -p herrnhut-kurrent.kraken/metrics.png herrnhut-kurrent.kraken/herrnhut-kurrent_{0..46}.mlmodel

metrics

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment