Skip to content

Instantly share code, notes, and snippets.

@NMZivkovic
Created March 24, 2019 14:06
Show Gist options
  • Save NMZivkovic/05a7350af6087d980bcb883a20a725fc to your computer and use it in GitHub Desktop.
Save NMZivkovic/05a7350af6087d980bcb883a20a725fc to your computer and use it in GitHub Desktop.
/**
 * Runs the model on one batch pulled from `data` and returns the predicted
 * and expected class indices for that batch.
 *
 * @param {object} model - Object exposing `predict(tensor)` that returns a
 *     tensor-like value with `argMax` and `dispose`.
 * @param {object} data - Batch provider exposing `nextDataBatch(size, flag)`
 *     that returns `{xs, labels}`.
 * @param {number} [testDataSize=500] - Number of examples to request; also
 *     used as the leading dimension of the reshaped input.
 * @param {number[]} [imageShape=[28, 28, 1]] - Per-example dimensions appended
 *     after the batch dimension when reshaping `xs`.
 * @returns {Array} `[preds, labels]`, both produced by `argMax([-1])`. The
 *     caller owns both values and must dispose them.
 */
function predict(model, data, testDataSize = 500, imageShape = [28, 28, 1]) {
  // The second argument is always `true`; presumably it selects the test
  // split -- confirm against the data provider.
  const testData = data.nextDataBatch(testDataSize, true);
  const testxs = testData.xs.reshape([testDataSize, ...imageShape]);
  let scores;
  try {
    const labels = testData.labels.argMax([-1]);
    scores = model.predict(testxs);
    const preds = scores.argMax([-1]);
    return [preds, labels];
  } finally {
    // The raw model output and the reshaped input are only needed to derive
    // `preds`; release them even when prediction throws.
    if (scores) {
      scores.dispose();
    }
    testxs.dispose();
    // NOTE(review): `testData.xs` and `testData.labels` are not disposed here
    // because this block cannot tell whether the provider reuses them. If
    // `nextDataBatch` allocates fresh tensors per call, they should be
    // disposed as well -- confirm.
  }
}
/**
 * Computes per-class accuracy for one prediction batch and shows it in the
 * tfjs-vis visor under the "Evaluation" tab.
 *
 * @param {object} model - Model forwarded to `predict`.
 * @param {object} data - Batch provider forwarded to `predict`.
 * @returns {Promise<void>} Settles once the metric has been computed and the
 *     render call has finished; rejects if either step fails.
 */
async function displayAccuracyPerClass(model, data) {
  const [preds, labels] = predict(model, data);
  try {
    const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
    const container = {name: 'Accuracy', tab: 'Evaluation'};
    // Awaited so a rendering failure reaches the caller instead of becoming
    // an unhandled rejection.
    await tfvis.show.perClassAccuracy(container, classAccuracy, classNames);
  } finally {
    // Both values returned by `predict` are owned here; release them whether
    // or not the metric/render steps succeeded.
    labels.dispose();
    preds.dispose();
  }
}
/**
 * Computes the confusion matrix for one prediction batch and renders it in
 * the tfjs-vis visor under the "Evaluation" tab.
 *
 * @param {object} model - Model forwarded to `predict`.
 * @param {object} data - Batch provider forwarded to `predict`.
 * @returns {Promise<void>} Settles once the matrix has been computed and the
 *     render call has finished; rejects if either step fails.
 */
async function displayConfusionMatrix(model, data) {
  const [preds, labels] = predict(model, data);
  try {
    const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds);
    const container = {name: 'Confusion Matrix', tab: 'Evaluation'};
    // Class names go in the data object as `tickLabels`; the third argument
    // of `render.confusionMatrix` is an options bag, so names passed there
    // are not used as axis labels.
    await tfvis.render.confusionMatrix(
        container, {values: confusionMatrix, tickLabels: classNames});
  } finally {
    // Both values returned by `predict` are owned here; release them whether
    // or not the metric/render steps succeeded.
    labels.dispose();
    preds.dispose();
  }
}
/**
 * Produces the evaluation reports for a model: the per-class accuracy view
 * first, then the confusion matrix. The second report starts only after the
 * first has finished.
 *
 * @param {object} model - Model handed to each report function.
 * @param {object} data - Batch provider handed to each report function.
 * @returns {Promise<void>} Settles when both reports are done; rejects with
 *     the first failure, in which case later reports are not started.
 */
async function evaluateModelFunction(model, data) {
  const reports = [displayAccuracyPerClass, displayConfusionMatrix];
  for (const runReport of reports) {
    await runReport(model, data);
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment