A Pen by Shigeru Kobayashi on CodePen.
Created September 3, 2018 07:59
Save kotobuki/8d8ae9b00ddc06a0ea55edfbe005f59f to your computer and use it in GitHub Desktop.
Transfer Learning with Data Augmentation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Start the sketch with an empty console.
console.clear();

// Purpose codes passed to handleFile() to tell training uploads from test uploads.
const FOR_TRAINING = 0;
const FOR_TESTING = 1;

// Side length, in pixels, of the square canvas that every image is drawn onto.
const imageSize = 224;

// ml5 objects, assigned in setup().
let featureExtractor;
let classifier;

// Canvas and DOM elements, assigned in setup() (the inputs once the model is ready).
let canvas;
let status;
let labelInput;
let fileInput;
let trainButton;
let testFileInput;

// Offsets of the canvas centre, computed in setup() after the canvas exists.
let halfWidth;
let halfHeight;
/**
 * p5.js entry point. Creates the imageSize x imageSize canvas and a status line,
 * then starts loading the MobileNet feature extractor. Once the model is ready
 * it adds the label box, the multi-file training picker, the Train button and
 * the single-file test picker.
 */
function setup() {
  canvas = createCanvas(imageSize, imageSize);
  canvas.position(0, 0);
  halfWidth = width / 2;
  halfHeight = height / 2;

  status = createSpan("Loading the model...");
  status.position(0, 230);

  // Invoked by ml5 when MobileNet has loaded; builds the controls top to bottom.
  const onModelReady = () => {
    status.html("Please specify an index and upload image files");

    labelInput = createInput("label");
    labelInput.position(0, 260);

    const addTrainingFile = file => handleFile(file, FOR_TRAINING);
    fileInput = createFileInput(addTrainingFile, "multiple");
    // Sit the training picker just to the right of the label box.
    fileInput.position(labelInput.width + 10, 260);

    trainButton = createButton("Train");
    trainButton.mousePressed(trainRequested);
    trainButton.position(0, 290);

    const addTestFile = file => handleFile(file, FOR_TESTING);
    testFileInput = createFileInput(addTestFile);
    testFileInput.position(0, 320);
  };
  featureExtractor = ml5.featureExtractor("MobileNet", onModelReady);

  // NOTE(review): classification() is called immediately, i.e. possibly before
  // onModelReady has run. Confirm the ml5 version in use allows this.
  classifier = featureExtractor.classification();
}
// ---- Data augmentation settings for training images ----

/**
 * Largest random rotation in either direction, in degrees.
 * NOTE(review): p5's rotate() uses radians unless angleMode(DEGREES) is set,
 * so code applying this value must convert it.
 */
let rotationRange = 10;

/** Largest random horizontal shift, as a fraction of the total width. */
let widthShiftRange = 0.1;

/** Largest random vertical shift, as a fraction of the total height. */
let heightShiftRange = 0.1;

/** Largest random zoom deviation: the zoom factor lies in [1 - zoomRange, 1 + zoomRange]. */
let zoomRange = 0.2;

/** When true, each augmented copy is mirrored horizontally with 50% probability. */
let horizontalFlip = true;

/** When true, each augmented copy is mirrored vertically with 50% probability. */
let verticalFlip = true;
/**
 * Handles one file chosen in either file input.
 *
 * For FOR_TRAINING, the image is drawn onto the canvas `numAugmentations`
 * times with a random rotation, shift, zoom and optional flips (see the
 * augmentation settings). Each canvas snapshot is added to the classifier under
 * the label that was in `labelInput` when the file was handled. For FOR_TESTING,
 * the image is drawn unmodified, the snapshot is classified, and the result is
 * shown in the status line.
 *
 * @param {object} file - p5 file object from createFileInput(); only files
 *   whose `type` is "image" are accepted.
 * @param {number} purpose - FOR_TRAINING or FOR_TESTING.
 * @param {number} [numAugmentations=10] - Number of augmented copies added
 *   for a training image.
 */
function handleFile(file, purpose, numAugmentations = 10) {
  if (file.type !== "image") {
    status.html("Please upload an image file.");
    return;
  }

  // Read the label once, now. Snapshots reach the classifier asynchronously,
  // so reading the input later could attach a label typed after uploading.
  const label = labelInput.value();

  // The loaded image is only needed inside the callbacks, so loadImage()'s
  // return value is not kept.
  loadImage(
    file.data,
    loaded => {
      if (purpose === FOR_TRAINING) {
        for (let i = 0; i < numAugmentations; i++) {
          drawAugmentedImage(loaded);
          console.log("adding an image for " + label);
          withCanvasSnapshot(snapshot => {
            classifier.addImage(snapshot.elt, label);
          });
        }
      } else if (purpose === FOR_TESTING) {
        resetMatrix();
        background(0);
        image(loaded, 0, 0, width, height);
        withCanvasSnapshot(snapshot => {
          classifier.classify(snapshot.elt, (err, result) => {
            if (err) {
              console.error(err);
              status.html("Classification failed.");
              return;
            }
            // TODO: show several ranked results instead of only the top one.
            status.html(result);
          });
        });
      }
    },
    error => {
      console.error(error);
      status.html("Failed to load the image file.");
    }
  );
}

/**
 * Draws `source` onto a black canvas with one random augmentation. The
 * augmentation is a rotation of up to ±rotationRange degrees about the canvas
 * centre, a zoom factor in [1 - zoomRange, 1 + zoomRange] (mirrored per axis
 * when the flip settings allow), and an offset of up to
 * ±widthShiftRange / ±heightShiftRange × imageSize in the rotated, scaled frame.
 */
function drawAugmentedImage(source) {
  resetMatrix();
  background(0);
  translate(halfWidth, halfHeight);
  // rotationRange is in degrees, but p5's rotate() takes radians by default.
  rotate(radians(random(-rotationRange, rotationRange)));
  const widthShift = imageSize * random(-widthShiftRange, widthShiftRange);
  const heightShift = imageSize * random(-heightShiftRange, heightShiftRange);
  const zoom = 1 + random(-zoomRange, zoomRange);
  // A negative scale mirrors the image along that axis.
  const horizontalScale = horizontalFlip && random() >= 0.5 ? -zoom : zoom;
  const verticalScale = verticalFlip && random() >= 0.5 ? -zoom : zoom;
  scale(horizontalScale, verticalScale);
  image(source, widthShift - halfWidth, heightShift - halfHeight, width, height);
}

/**
 * Copies the current canvas into a temporary <img> element. Once it has loaded,
 * the element is sized to imageSize x imageSize and passed to `onReady`.
 */
function withCanvasSnapshot(onReady) {
  // We'll be able to supply the canvas directly in the near future
  // See https://github.com/ml5js/ml5-library/pull/206
  const snapshot = createImg(canvas.elt.toDataURL(), "Failed to create", () => {
    snapshot.class("image").size(imageSize, imageSize);
    onReady(snapshot);
  });
  // Detach right away so the snapshot is never displayed on the page; a
  // detached <img> still loads its data URL and fires the callback above.
  snapshot.remove();
}
/**
 * Click handler for the "Train" button. Starts training the classifier on the
 * images added so far and mirrors each progress callback in the status line.
 */
function trainRequested() {
  // A truthy `loss` is shown as progress; anything falsy is treated as the end of training.
  const reportProgress = loss => {
    const message = loss
      ? "Training... loss: " + loss
      : "Done training! Upload an image file to test.";
    status.html(message);
  };
  classifier.train(reportProgress);
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!-- The ml5 version in this URL had been corrupted into an e-mail-protection
     placeholder. "ml5@0.1" lets unpkg resolve the latest 0.1.x release.
     NOTE(review): confirm the featureExtractor / addImage / train / classify
     calls used by the sketch behave the same on that release, or pin the
     exact version originally used. -->
<script src="https://unpkg.com/ml5@0.1/dist/ml5.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.7.2/p5.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.7.2/addons/p5.dom.min.js"></script>
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.