Last active: February 25, 2024 16:23
-
-
Save bertsky/9b9892bbbf4c7f84d5cd1375b53d8cf4 to your computer and use it in GitHub Desktop.
Dump the user metadata of a kraken model file, or fix it.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
"""Dump user metadata of a kraken model file or fix it."""
import json
import os

import click

# Raise TensorFlow's C++ log threshold to '3' (errors) unless the caller has
# already chosen a level. This happens before kraken is imported, presumably
# so that a TensorFlow-backed dependency picks it up at load time -- TODO
# confirm it is still needed.
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '3')

from kraken.lib import vgsl  # noqa: E402  (must follow the environment tweak)
import numpy as np  # noqa: E402

# Verbosity count collected from the -v/--verbose option of the command group.
is_verbose = 0
@click.group()
@click.option('-v', '--verbose', count=True)
@click.version_option(version='0.0')
def cli(verbose):
    # Entry point of the command group. -v/--verbose may be repeated; its
    # count is stored in the module-level ``is_verbose``.
    # Without ``global`` the assignment below only bound a local name, so the
    # module-level ``is_verbose`` always remained 0.
    # NOTE(review): no command visible in this file reads is_verbose yet.
    global is_verbose
    is_verbose = verbose
@cli.command(help='Dump user metadata of Core ML file')
@click.argument('models', nargs=-1, type=click.Path(exists=True))
def dump(models):
    """Print the user metadata of each MODEL file as JSON.

    One document per model is printed, in command-line order, with keys
    sorted and an indent of 4 spaces.
    """
    for path in models:
        user_metadata = vgsl.TorchVGSLModel.load_model(path).user_metadata
        print(json.dumps(user_metadata, indent=4, sort_keys=True))
def _checkpoint_metrics(metadata):
    """Summarise the ``metrics`` history stored in one checkpoint's user metadata.

    Each item of ``metadata['metrics']`` is read as a pair whose second
    element is a mapping with ``train_loss`` and ``val_accuracy`` entries.

    Returns ``(summed train_loss, mean val_accuracy)`` over all items. A
    missing or empty history yields ``(0.0, nan)`` instead of raising.
    """
    history = [item[1] for item in metadata.get('metrics') or []]
    train_loss = np.sum([values['train_loss'] for values in history])
    if not history:
        # np.average([]) would also give nan, but with a RuntimeWarning.
        return train_loss, np.nan
    val_accuracy = np.average([values['val_accuracy'] for values in history])
    return train_loss, val_accuracy


def _plot_metrics(path, models, train_losses, val_accuracies):
    """Save a twin-axis plot of log train-loss increments and val_accuracy to *path*."""
    from matplotlib import pyplot as plt

    # The per-checkpoint train_loss figure seems to be cumulative, so plot the
    # increase from one checkpoint to the next. np.diff returns one value fewer
    # than there are checkpoints: value i is the step from checkpoint i to
    # i+1, so it is drawn at x = i+1 to stay aligned with val_accuracy, which
    # has a point for every checkpoint. Non-positive steps give -inf/nan under
    # np.log and are not drawn.
    loss_steps = np.diff(train_losses)
    loss_x = np.arange(1, len(loss_steps) + 1)
    accuracy_x = np.arange(len(val_accuracies))

    fig, ax1 = plt.subplots()
    color = 'tab:blue'
    ax1.set_xlabel('epochs')
    ax1.set_ylabel('log loss', color=color)
    ax1.plot(loss_x, np.log(loss_steps), 'x-', color=color, label='train_loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.grid(axis='y')

    ax2 = ax1.twinx()
    color = 'tab:red'
    ax2.set_ylabel('accuracy', color=color)
    ax2.plot(accuracy_x, val_accuracies, 'o-', color=color, label='val_accuracy')
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.grid(axis='y')

    fig.legend()
    plt.title(os.path.commonprefix(models))
    fig.savefig(path)
    # Release the figure; pyplot keeps every open figure alive otherwise.
    plt.close(fig)


@cli.command(help='Extract training/validation metrics from model checkpoints')
@click.option('-p', '--plot', help="render course of metrics as an image file", type=click.Path())
@click.argument('models', nargs=-1, type=click.Path(exists=True))
def metrics(plot, models):
    """Print one tab-separated metrics row per MODEL checkpoint; optionally plot them.

    Each row holds the summed train_loss and the mean val_accuracy over the
    checkpoint's stored metrics history. With ``--plot``, an image of the
    course of both values across the checkpoints is written to that path.
    """
    train_losses = []
    val_accuracies = []
    print("train_loss\tval_accuracy")
    for filename in models:
        metadata = vgsl.TorchVGSLModel.load_model(filename).user_metadata
        train_loss, val_accuracy = _checkpoint_metrics(metadata)
        print(f"{train_loss}\t{val_accuracy}")
        train_losses.append(train_loss)
        val_accuracies.append(val_accuracy)
    if plot:
        _plot_metrics(plot, models, train_losses, val_accuracies)
def _fix_metadata(metadata):
    """Backfill missing ``val_metric`` entries in *metadata*, in place.

    For every item of ``metadata['metrics']`` whose mapping (second element)
    lacks ``val_metric``, do two things:
    - copy its ``val_accuracy`` into ``val_metric``;
    - replace the first negative value among ``metadata['accuracy']`` entries
      with the same key (first element) by that accuracy.

    Missing ``metrics`` or ``accuracy`` lists are treated as empty; the
    original crashed with TypeError in that case.

    Returns True if at least one entry was changed.
    """
    accuracy = metadata.get('accuracy') or []
    needs_fix = False
    for item in metadata.get('metrics') or []:
        key, values = item[0], item[1]
        if 'val_metric' in values:
            continue
        needs_fix = True
        val_accuracy = values['val_accuracy']
        values['val_metric'] = val_accuracy
        for entry in accuracy:
            if entry[0] == key and entry[1] < 0.0:
                entry[1] = val_accuracy
                break
    return needs_fix


@cli.command(help='Fix Core ML file')
@click.argument('model', nargs=-1, type=click.Path(exists=True))
def fix(model):
    """Repair the metrics metadata of each MODEL file.

    Each model that needs repair is written to ``fixed/<filename>``, keeping
    the source file's access and modification times. Other models are only
    reported.
    """
    for filename in model:
        m = vgsl.TorchVGSLModel.load_model(filename)
        # Mutates the metadata object held by the model, so save_model below
        # writes the repaired values (the original relied on this as well).
        if not _fix_metadata(m.user_metadata):
            print(f'No fix {filename}')
            continue
        print(f'Fixing {filename}')
        target = f'fixed/{filename}'
        # Create fixed/ (and any subdirectories mirrored from filename) so that
        # saving does not depend on them already existing.
        os.makedirs(os.path.dirname(target), exist_ok=True)
        m.save_model(target)
        statinfo = os.stat(filename)
        os.utime(target, ns=(statinfo.st_atime_ns, statinfo.st_mtime_ns))
if __name__ == "__main__":
    # Hand control to the click command group when run as a script.
    cli()
bertsky (author) commented on Jan 10, 2024:
- Based on https://ub-backup.bib.uni-mannheim.de/~stweil/tesstrain/kraken/mlmodel.py
- With additional model metadata (`input`, `one_channel_mode`, `spec`).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.