@N8python
Created September 23, 2020 21:43
async function main() {
    const fs = require("fs");
    const fsExtra = require("fs-extra");
    const tf = require("@tensorflow/tfjs-node");

    // Build the character vocabulary and index lookup tables from the corpus.
    const text = fs.readFileSync("input.txt").toString();
    const chars = Array.from(new Set(text.split("")));
    const encoding = Object.fromEntries(chars.map((x, i) => [x, i]));
    const decoding = Object.fromEntries(chars.map((x, i) => [i, x]));

    // Each training sample is a window of `sampleLength` characters; training
    // proceeds over `epochSize` windows at a time, cycling through the text.
    const sampleLength = 50;
    const epochSize = 2000;
    let currEpochIndex = 0;
    let data = [];
    let labels = [];

    // Start with a clean outputs/ directory for the generated text samples.
    if (!fs.existsSync("outputs")) {
        fs.mkdirSync("outputs");
    } else {
        fsExtra.emptyDirSync("outputs");
    }
    // One-hot encode a single character as a vector of length chars.length.
    function oneHotEncode(char) {
        const vec = Array(chars.length).fill(0);
        vec[encoding[char]] = 1;
        return vec;
    }
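    // Illustrative: with a vocabulary of ["a", "b", "c"], oneHotEncode("b")
    // would return [0, 1, 0].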
    // Sample one character index from the model's probability output, with
    // temperature scaling applied in log space.
    function sample(probs, temperature) {
        return tf.tidy(() => {
            // `logits` parameterizes a multinomial distribution, scaled by the
            // temperature; we draw a single random sample from it.
            const logits = tf.div(tf.log(probs), Math.max(temperature, 1e-6));
            // `false` tells tf.multinomial the values are unnormalized logits.
            const isNormalized = false;
            return tf.multinomial(logits, 1, null, isNormalized).dataSync()[0];
        });
    }
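    // Illustrative: for probs [0.7, 0.2, 0.1], temperature 0.5 sharpens the
    // distribution toward the likeliest character, while temperatures above 1
    // flatten it toward uniform sampling.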
    // Slide a window over the text: each window of `sampleLength` one-hot
    // vectors predicts the character that immediately follows it.
    const charList = text.split("").map(oneHotEncode);
    for (let i = 0; i < charList.length - sampleLength; i++) {
        data.push(charList.slice(i, i + sampleLength));
        labels.push(charList[i + sampleLength]);
    }
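    // Illustrative: with text "hello" and sampleLength 3, the pairs would be
    // ("hel" -> "l") and ("ell" -> "o"), each character one-hot encoded.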
    let trainData = tf.tensor(data.slice(currEpochIndex, currEpochIndex + epochSize));
    let trainLabels = tf.tensor(labels.slice(currEpochIndex, currEpochIndex + epochSize));
    // Three stacked 512-unit LSTMs: the first two return full sequences so the
    // next layer sees every timestep, the last returns only its final state,
    // which the dense softmax layer maps to a next-character distribution.
    const model = tf.sequential({
        layers: [
            tf.layers.lstm({ inputShape: [null, chars.length], units: 512, activation: "relu", returnSequences: true }),
            tf.layers.lstm({ units: 512, activation: "relu", returnSequences: true }),
            tf.layers.lstm({ units: 512, activation: "relu", returnSequences: false }),
            tf.layers.dense({ units: chars.length, activation: "softmax" }),
        ]
    });
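    // Shape walkthrough: input [batch, timesteps, chars.length] ->
    // [batch, timesteps, 512] -> [batch, timesteps, 512] -> [batch, 512] ->
    // softmax over [batch, chars.length].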
    // Generate `length` characters by seeding with one random character and
    // repeatedly feeding the last `sampleLength` characters back into the model.
    function outputText(length) {
        let sentence = [chars[Math.floor(Math.random() * chars.length)]];
        let context = [oneHotEncode(sentence[0])];
        for (let i = 0; i < length - 1; i++) {
            const inputTensor = tf.tensor3d([context]);
            const prediction = model.predict(inputTensor);
            const output = Array.from(prediction.dataSync());
            // Dispose intermediate tensors so generation does not leak memory.
            inputTensor.dispose();
            prediction.dispose();
            // Sample at temperature 0.5 instead of taking the argmax, so
            // repeated generations differ.
            const idx = sample(output, 0.5);
            sentence.push(decoding[idx]);
            context.push(oneHotEncode(decoding[idx]));
            if (context.length > sampleLength) {
                context.shift();
            }
        }
        return sentence.join("");
    }
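    // e.g. outputText(100) returns a 100-character sample from the model as
    // currently trained.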
    // Note: `learningRate` and `clipValue` are not valid compile() options in
    // tfjs; the learning rate must be set on the optimizer itself, and
    // tf.train.adam exposes no value-clipping option.
    model.compile({
        optimizer: tf.train.adam(0.0001),
        loss: "categoricalCrossentropy",
        metrics: ["accuracy"]
    });
    // Train one pass over the current slice per call, then schedule the next
    // call with setTimeout so the recursion never grows the call stack.
    let epochAmt = 500;
    function fitModel(epochNum = 0) {
        model.fit(trainData, trainLabels, {
            epochs: 1,
            batchSize: 128,
            callbacks: {
                onBatchEnd(batch, logs) {
                    console.log(logs);
                    // console.log(outputText(100));
                },
                onTrainEnd(logs) {
                    console.log("EPOCH OVER");
                    // Slide the training window forward, wrapping before the
                    // end of the dataset.
                    currEpochIndex += epochSize;
                    if (currEpochIndex >= data.length - epochSize * 2) {
                        currEpochIndex = 0;
                    }
                    // Dispose the old tensors before replacing them, so each
                    // new slice does not leak memory.
                    trainData.dispose();
                    trainLabels.dispose();
                    trainData = tf.tensor(data.slice(currEpochIndex, currEpochIndex + epochSize));
                    trainLabels = tf.tensor(labels.slice(currEpochIndex, currEpochIndex + epochSize));
                    // Write a long (1000-char) sample every tenth epoch and a
                    // short (100-char) one otherwise.
                    if ((epochNum + 1) % 10 === 0) {
                        fs.writeFileSync(`outputs/epoch${epochNum + 1}.txt`, outputText(1000));
                    } else {
                        fs.writeFileSync(`outputs/epoch${epochNum + 1}.txt`, outputText(100));
                    }
                    if (epochNum < epochAmt) {
                        setTimeout(() => {
                            fitModel(epochNum + 1);
                        }, 0);
                    } else if (logs && !Number.isNaN(logs.loss)) {
                        // Only save if the final loss is a real number
                        // (guards against a diverged, NaN-loss model).
                        (async () => { await model.save(`file://./model`); })();
                    }
                }
            }
        });
    }
    fitModel();
}
main();
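
// A minimal sketch (an assumption, not part of the original gist) of how the
// saved model could be reloaded later for generation:
//
//   const tf = require("@tensorflow/tfjs-node");
//   const model = await tf.loadLayersModel("file://./model/model.json");
//   // Rebuild `chars`, `encoding`, and `decoding` from the same input.txt,
//   // then reuse the outputText-style loop above with this model.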