Last active
May 17, 2023 03:14
-
-
Save kurtlawrence/7921fca8751ce2ea4826f8ab0a90ee32 to your computer and use it in GitHub Desktop.
Construct HTML document with Plotly charts of burn-rs training progression
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! A small utility to construct an HTML document with Plotly charts of the training metrics
//! by reading the log entries in burn's artifact directory.
//!
//! The file can be compiled and executed with `rustc` (`-O` for optimised):
//! ```sh
//! rustc chart-metrics.rs && ./chart-metrics <ARTIFACT-DIR> [<OUTPUT-FILE>] [--avg-epoch]
//! ```
//! Without `<OUTPUT-FILE>` the HTML is written to stdout.
//!
//! Source code: https://gist.github.com/kurtlawrence/7921fca8751ce2ea4826f8ab0a90ee32
use std::collections::BTreeMap; | |
use std::fs; | |
use std::path::{Path, PathBuf}; | |
/// Result type used throughout this tool: errors are boxed as trait objects so `?` can
/// propagate I/O failures and plain string messages alike.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + 'static>>;
fn main() -> Result<()> { | |
let mut args = std::env::args().skip(1).collect::<Vec<_>>(); | |
let agg_epoch = args | |
.iter() | |
.enumerate() | |
.find_map(|(i, x)| (x == "--avg-epoch").then_some(i)); | |
if let Some(x) = agg_epoch { | |
args.remove(x); | |
} | |
let agg_epoch = agg_epoch.is_some(); | |
let mut args = args.into_iter(); | |
let artifacts_dir = args.next().map(PathBuf::from).ok_or_else(|| { | |
print_usage(); | |
"expecting a burn artifacts directory" | |
})?; | |
let output_file = args.next().map(PathBuf::from); | |
let metrics = read_dir(&artifacts_dir)?; | |
match &output_file { | |
Some(x) => eprintln!("Constructing HTML report at `{}`", x.display()), | |
None => eprintln!("Constructing HTML report"), | |
} | |
let injection = metrics.into_iter().fold(String::new(), |x, (k, v)| { | |
x + "\n\n" + &plot_html(k, v, agg_epoch) | |
}); | |
let html = HTML.replace("{{inject}}", &injection); | |
match output_file { | |
None => println!("{html}"), | |
Some(x) => fs::write(x, html)?, | |
} | |
Ok(()) | |
} | |
/// Writes the command-line help text to stderr, one line per entry.
fn print_usage() {
    let help = [
        "Read training logs in a burn artifact directory and output a HTML file with plots",
        "usage: <ARTIFACT-DIR> [<OUTPUT-FILE>] [--avg-epoch]",
        "--avg-epoch: optionally average the metric per epoch",
    ];
    for line in help.iter() {
        eprintln!("{}", line);
    }
}
/// A single value read from a metric log.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Metric {
    /// Epoch number, parsed from the `epoch-<N>` directory the log was found in.
    epoch: u32,
    /// The value parsed from one line of the log file.
    value: f64,
    /// Whether the value was logged during training or validation.
    stg: Stage,
}

/// The stage a metric value was logged in, one per `train` / `valid` sub-directory of the
/// artifacts directory.
///
/// `Ord` is derived, with `Train` ordering before `Valid`, so `(epoch, stage)` pairs can key
/// an ordered map. The `Debug` output (`Train` / `Valid`) appears in progress messages.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Stage {
    Train,
    Valid,
}

/// Metric name (the log file's stem) mapped to every value logged for it, across both stages.
type Map = BTreeMap<String, Vec<Metric>>;
fn read_dir(path: &Path) -> Result<Map> { | |
let map = read_stage(BTreeMap::new(), path.join("train").as_ref(), Stage::Train)?; | |
let mut map = read_stage(map, path.join("valid").as_ref(), Stage::Valid)?; | |
// sort by epoch, cycling between train and valid | |
map.values_mut() | |
.for_each(|vs| vs.sort_by(|a, b| a.epoch.cmp(&b.epoch))); | |
Ok(map) | |
} | |
fn read_stage(mut map: Map, path: &Path, stg: Stage) -> Result<Map> { | |
for d in path.read_dir()? { | |
let d = d?; | |
let Some(epoch) = d.file_name().to_str().and_then(|x| x.strip_prefix("epoch-")).and_then(|x| x.parse::<u32>().ok()) else { continue; }; | |
eprintln!("Reading metrics for {stg:?} epoch-{epoch}"); | |
for d in d.path().read_dir()? { | |
let p = d?.path(); | |
let Some(name) = p.file_stem().and_then(|x| x.to_str()).map(ToString::to_string) else { continue; }; | |
let es = map.entry(name).or_default(); | |
parse_values(&p, |value| es.push(Metric { epoch, stg, value }))?; | |
} | |
} | |
Ok(map) | |
} | |
fn parse_values<F: FnMut(f64)>(file: &Path, cb: F) -> Result<()> { | |
let x = fs::read_to_string(file)?; | |
x.lines().filter_map(|x| x.parse::<f64>().ok()).for_each(cb); | |
Ok(()) | |
} | |
fn plot_html(metric: String, values: Vec<Metric>, avg_epoch: bool) -> String { | |
use std::fmt::Write; | |
let mut s = format!( | |
r#" | |
<h2>Plot of {metric} metric</h2> | |
<div id="{metric}" style="height: 600px;"></div> | |
<script> | |
"# | |
); | |
let (ts, vs): (Vec<(usize, f64)>, Vec<(usize, f64)>) = if avg_epoch { | |
Box::new( | |
values | |
.into_iter() | |
.fold(BTreeMap::new(), |mut map, m| { | |
let e: &mut Vec<f64> = map.entry((m.epoch, m.stg)).or_default(); | |
e.push(m.value); | |
map | |
}) | |
.into_iter() | |
.map(|((epoch, stg), vs)| { | |
let value = if vs.is_empty() { | |
0. | |
} else { | |
vs.iter().sum::<f64>() / vs.len() as f64 | |
}; | |
(epoch as usize, Metric { epoch, stg, value }) | |
}), | |
) as Box<dyn Iterator<Item = (usize, Metric)>> | |
} else { | |
// we accumulate the iteration as we go | |
Box::new(values.into_iter().enumerate()) | |
} | |
.fold((Vec::new(), Vec::new()), |(mut ts, mut vs), (x, m)| { | |
match m.stg { | |
Stage::Train => ts.push((x, m.value)), | |
Stage::Valid => vs.push((x, m.value)), | |
} | |
(ts, vs) | |
}); | |
let (xst, yst): (Vec<_>, Vec<_>) = ts.into_iter().unzip(); | |
let (xsv, ysv): (Vec<_>, Vec<_>) = vs.into_iter().unzip(); | |
writeln!( | |
&mut s, | |
r#"var train = {{ | |
x: {xst:?}, | |
y: {yst:?}, | |
mode: mode, | |
type: 'scatter', | |
name: 'Train' | |
}}; | |
var valid = {{ | |
x: {xsv:?}, | |
y: {ysv:?}, | |
mode: mode, | |
type: 'scatter', | |
name: 'Valid' | |
}}; | |
var id = document.getElementById('{metric}'); | |
plots.push(id); | |
Plotly.newPlot(id, [train,valid], {{ xaxis: {{ title: '{title}' }} }}); | |
</script> | |
"#, | |
title = if avg_epoch { "Epoch" } else { "Iteration" } | |
) | |
.ok(); | |
s | |
} | |
/// Page template for the report.
///
/// `main` replaces the `{{inject}}` placeholder with the per-metric fragments built by
/// `plot_html`. Those fragments rely on the `mode` and `plots` globals declared here, and
/// `switchMode` toggles every registered plot between markers and lines.
///
/// The doctype must be exactly `<!DOCTYPE html>`; the previous `<!DOCTYPE=HTML>` was
/// malformed and put browsers into quirks mode. The charset is declared because the output
/// is UTF-8 and may contain non-ASCII metric names.
const HTML: &str = r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Training metrics</title>
<script src="https://cdn.plot.ly/plotly-2.20.0.min.js" charset="utf-8"></script>
<script>
var mode = 'markers';
var plots = [];
function switchMode(el) {
  if (mode == 'markers') {
    mode = 'lines';
    el.innerText = 'Markers';
  } else {
    mode = 'markers';
    el.innerText = 'Lines';
  }
  plots.forEach(p => Plotly.restyle(p, { mode: mode }));
}
</script>
</head>
<body>
<button onclick="switchMode(this)">Lines</button>
{{inject}}
</body>
</html>
"#;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment