Last active
May 17, 2024 02:54
-
-
Save YankeeTube/ee96f60f57b9038ee0b703fc6620e7d9 to your computer and use it in GitHub Desktop.
Very fast NSFWJS inference using the TensorFlow.js WASM backend in a Web Worker
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8" />
  <title>NSFWJS (TF.js WASM + Web Worker)</title>
</head>
<body>
  <div>
    <input type="file" id="file-input" accept="image/*" />
  </div>
  <!-- Scripts belong inside <body>; placing this one last also means the
       input above is already parsed when it runs. -->
  <script>
    // All model work happens in worker.js so the UI thread never blocks.
    const worker = new Worker('worker.js');
    worker.addEventListener('message', nsfwResult);
    // Surface uncaught worker errors instead of failing silently.
    worker.addEventListener('error', (event) => {
      console.error('NSFW worker error:', event.message);
    });
    // Ask the worker to load and warm up the model before the first image.
    worker.postMessage('init');

    document.addEventListener('DOMContentLoaded', () => {
      const fileInput = document.querySelector('#file-input');
      fileInput.addEventListener('change', imageHandler);
    });

    /** Receive the worker's classification result (label -> percentage map). */
    function nsfwResult({ data }) {
      console.log(data);
    }

    /**
     * Send the selected file to the worker for classification.
     * Does nothing when the selection was cancelled (no file chosen), which
     * would otherwise post `undefined` and make createImageBitmap reject.
     */
    function imageHandler(e) {
      const [file] = e.target.files;
      if (!file) {
        return;
      }
      worker.postMessage(file);
    }
  </script>
</body>
</html>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
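// Download model.json and group1-shard1of1 from the directory below and put them
// in a `models/` folder next to worker.js (the worker loads `models/model.json`):
// https://github.com/infinitered/nsfwjs/tree/master/example/nsfw_demo/public/quant_nsfw_mobilenet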
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Pin a single TF.js release for every CDN asset. The WASM backend script and
// the .wasm binaries it fetches must match the tf.min.js version, and
// unversioned jsDelivr URLs can resolve to different releases over time.
const TFJS_VERSION = '4.22.0';
const TFJS_CDN = 'https://cdn.jsdelivr.net/npm/@tensorflow';
importScripts(`${TFJS_CDN}/tfjs@${TFJS_VERSION}/dist/tf.min.js`);
importScripts(`${TFJS_CDN}/tfjs-backend-wasm@${TFJS_VERSION}/dist/tf-backend-wasm.js`);
// Where the WASM backend loads its .wasm binaries from (same pinned version).
tf.wasm.setWasmPaths(`${TFJS_CDN}/tfjs-backend-wasm@${TFJS_VERSION}/wasm-out/`);
tf.enableProdMode();
// The loaded model; undefined until the 'init' message handler assigns it.
let model;

// Edge length in pixels of the square input fed to the model
// (inputs are reshaped to [1, SIZE, SIZE, 3]).
const SIZE = 224;

// Model output index -> class label. Built from an ordered list so the
// position in the list is the output index.
const NSFW_CLASSES = Object.fromEntries(
  ['Drawing', 'Hentai', 'Neutral', 'Porn', 'Sexy'].map((label, index) => [index, label]),
);
/**
 * Rank per-class scores and return the top entries as percentages.
 *
 * @param {ArrayLike<number>} values - One score per class, indexed the same way
 *   as `labels` (e.g. the output of `predictions.data()`).
 * @param {object} [options]
 * @param {Object<number, string>} [options.labels=NSFW_CLASSES] - Class index -> label.
 * @param {number} [options.topK=5] - Maximum number of classes to return; fewer are
 *   returned when `values` is shorter.
 * @returns {Object<string, number>} Label -> (score * 100) rounded to two decimals.
 *   Keys are inserted from highest to lowest score; ties keep their index order
 *   because Array.prototype.sort is stable.
 */
function nsfwProcess(values, { labels = NSFW_CLASSES, topK = 5 } = {}) {
  // Never read past the end of `values` (the old fixed loop of 5 threw here).
  const count = Math.min(topK, values.length);
  const ranked = Array.from(values, (value, index) => ({ value, index }))
    .sort((a, b) => b.value - a.value)
    .slice(0, count);

  const result = {};
  for (const { value, index } of ranked) {
    result[labels[index]] = Number.parseFloat((value * 100).toFixed(2));
  }
  return result;
}
/**
 * Classify one image and post the result back to the page.
 *
 * Draws the bitmap onto an OffscreenCanvas and converts it to a float tensor
 * with pixel values scaled to 0..1. The tensor is resized bilinearly to
 * SIZE x SIZE when needed, then reshaped to [1, SIZE, SIZE, 3]. The model is
 * run on it, and the label -> percentage map from nsfwProcess() is logged and
 * sent with self.postMessage.
 *
 * Every tensor created here is disposed, even when prediction fails, so
 * repeated calls do not accumulate backend memory. The caller keeps ownership
 * of `bitmap`; this function does not close it.
 *
 * @param {ImageBitmap} bitmap - Image to classify.
 * @returns {Promise<void>}
 */
async function detectNSFW(bitmap) {
  const { width: w, height: h } = bitmap;
  const offScreen = new OffscreenCanvas(w, h);
  const ctx = offScreen.getContext('2d');
  ctx.drawImage(bitmap, 0, 0, w, h);
  // getImageData already returns an ImageData; no need to copy it into a new one.
  const img = ctx.getImageData(0, 0, w, h);

  // tf.tidy disposes every intermediate tensor allocated in the callback and
  // keeps only the returned batch tensor alive.
  const batched = tf.tidy(() => {
    const pixels = tf.browser.fromPixels(img);
    const normalized = pixels.toFloat().div(tf.scalar(255));
    const needsResize = pixels.shape[0] !== SIZE || pixels.shape[1] !== SIZE;
    const resized = needsResize
      ? tf.image.resizeBilinear(normalized, [SIZE, SIZE], true)
      : normalized;
    return resized.reshape([1, SIZE, SIZE, 3]);
  });

  let predictions;
  try {
    // predict() is synchronous; only reading the data back is async.
    predictions = model.predict(batched);
    const values = await predictions.data();
    const result = nsfwProcess(values);
    console.log(result);
    self.postMessage(result);
  } finally {
    batched.dispose();
    if (predictions) {
      predictions.dispose();
    }
  }
}
// Promise for the in-flight or completed model load. It is shared so that
// 'init' and image messages all wait on the same load instead of racing it.
let modelReady = null;

/**
 * Select the WASM backend and load the model into `model`. The IndexedDB copy
 * is tried first. If that fails, `models/model.json` is fetched and saved to
 * IndexedDB on a best-effort basis. Finally, one warm-up prediction is run.
 *
 * @returns {Promise<void>} Rejects if the backend or both model sources fail.
 */
async function loadModel() {
  await tf.setBackend('wasm');
  try {
    model = await tf.loadLayersModel('indexeddb://model');
    console.log('Load NSFW Model!');
  } catch {
    model = await tf.loadLayersModel('models/model.json');
    try {
      await model.save('indexeddb://model');
      console.log('Save NSFW Model!');
    } catch (saveError) {
      // Caching is optional (e.g. IndexedDB unavailable); the loaded model still works.
      console.warn('Could not cache NSFW model in IndexedDB:', saveError);
    }
  }
  // Warm up: run one dummy prediction before any real image. This runs only
  // after a successful load, so a load failure is reported as itself rather
  // than as a TypeError on an undefined model.
  const result = tf.tidy(() => model.predict(tf.zeros([1, SIZE, SIZE, 3])));
  await result.data();
  result.dispose();
}

/**
 * Start loading the model once and return the shared promise. If the load
 * fails, the promise is cleared so a later message can retry.
 * @returns {Promise<void>}
 */
function ensureModel() {
  if (!modelReady) {
    modelReady = loadModel().catch((err) => {
      modelReady = null;
      throw err;
    });
  }
  return modelReady;
}

/**
 * Worker 'message' handler.
 *
 * - `'init'`: load and warm up the model (only once, even if sent repeatedly).
 * - anything else: passed to createImageBitmap (the page posts a File). The
 *   handler waits for the model, classifies the image with detectNSFW(), and
 *   then releases the bitmap.
 *
 * Errors propagate as rejections of this handler.
 */
async function init({ data }) {
  await ensureModel();
  if (data === 'init') {
    return;
  }
  const bitmap = await createImageBitmap(data);
  try {
    await detectNSFW(bitmap);
  } finally {
    // Free the decoded image as soon as it has been classified.
    bitmap.close();
  }
}

addEventListener('message', init);
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment