<!DOCTYPE html>

<html>

<head>
  <title>Auto-complete</title>

  <!-- Import TensorFlow.js -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>

  <!-- Import tfjs-vis -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis/dist/tfjs-vis.umd.min.js"></script>

  <!-- Import visual frameworks -->
  <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css" integrity="sha384-JcKb8q3iqJ61gNV9KGb8thSsNjpSL0n8PARn9HuZOnIxN0hoP+VmmDGMN5t9UJ0Z" crossorigin="anonymous">
  <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.7.0/css/all.css" integrity="sha384-lZN37f5QGtY3VHgisS14W3ExzMWZxybE1SJSEsQp9S+oqd12jhcu+A56Ebc1zFSJ" crossorigin="anonymous">
</head>

<body>

  <div class="container p-0">
    <header>
      <nav class="navbar navbar-dark bg-dark rounded">
        <a class="navbar-brand">
          <i class="fas fa-comment"></i>
          <strong>Auto-complete</strong>
        </a>
      </nav>
    </header>
  </div>

<div class="container mt-3"> |
|
<div class="row"> |
|
<div class="col"> |
|
<label class="form-label">Input Text</label> |
|
<input class="form-control" type="input" id="pred_features" pattern="^[a-z]*$"> |
|
</div> |
|
<div class="col"> |
|
<label class="form-label">Predicted Text</label> |
|
<input class="form-control" type="input" id="pred_labels" disabled> |
|
</div> |
|
</div> |
|
<div> |
|
<div class="card mt-3 bg-light"> |
|
<h5 class="card-header">Training Parameters</h5> |
|
<div class="card-body"> |
|
<div class="row"> |
|
<div class="col-2"> |
|
<label>Max word length</label> |
|
</div> |
|
<div class="col-10"> |
|
<input type="number" id="max_len" onchange="max_len=parseInt(this.value);document.getElementById('pred_features').maxLength=parseInt(this.value)" min=1> |
|
</div> |
|
</div> |
|
<div class="row mt-3"> |
|
<div class="col-2"> |
|
<label>Epochs</label> |
|
</div> |
|
<div class="col-10"> |
|
<input type="number" id="epochs" onchange="epochs=parseInt(this.value)" min=1> |
|
</div> |
|
</div> |
|
<div class="row mt-3"> |
|
<div class="col-2"> |
|
<label>Batch Size</label> |
|
</div> |
|
<div class="col-10"> |
|
<input type="number" id="batch_size" onchange="batch_size=parseInt(this.value)" min=1> |
|
</div> |
|
</div> |
|
<div class="row mt-3"> |
|
<div class="col-2"> |
|
<button class="btn btn-secondary btn-sm" style="cursor:pointer;" type="button" onclick="document.getElementById('file').click()"> |
|
<i class="fa fa-upload" aria-hidden="true"></i> |
|
Dataset |
|
</button> |
|
<input style="display:none;"type="file" id="file" > |
|
</div> |
|
<div class="col-10"> |
|
<label id="file_name">Supported format: csv, tsv, txt.</label> |
|
</div> |
|
</div> |
|
|
|
</div> |
|
</div> |
|
<input class="btn btn-secondary btn-block mt-3" type="button" id="train" value="Train"> |
|
<div class="d-flex justify-content-center mt-3" onclick="showVizer()"> |
|
<span class="spinner-grow spinner-grow" style="display:none" id="status" aria-hidden="true"></span> |
|
</div> |
|
<div class="float-right"> |
|
<a href="https://www.youtube.com/user/chirpieful"> |
|
<i class="fab fa-youtube text-danger" aria-hidden="true"></i> |
|
</a> |
|
<a href="https://ohyicong.medium.com/"> |
|
<i class="fab fa-medium text-dark" aria-hidden="true"></i> |
|
</a> |
|
</div> |
|
</div> |
|
</body> |
|
  <!-- main script file -->
  <script>
    const ALPHA_LEN = 26;  // size of the lowercase alphabet a-z
    var sample_len = 1;    // minimum prefix length used for training samples
    var batch_size = 32;
    var epochs = 250;
    var max_len = 10;      // maximum word length; shorter words are zero-padded
    var words = [];        // raw wordlist loaded from the dataset file
    var model = create_model(max_len, ALPHA_LEN);  // untrained model so prediction works before training
    var status = "";
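
    // Pipeline overview: dataset file -> wordlist -> integer encoding
    // (0 = padding, a = 1 .. z = 26) -> sliding-window prefixes (features)
    // paired with the full words (labels) -> one-hot tensors -> LSTM ->
    // decoded per-position letter predictions.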

    function setup(){
      // Load the uploaded dataset and guess its delimiter.
      document.getElementById('file')
        .addEventListener('change', function() {
          var fr = new FileReader();
          fr.onload = function(){
            var result = fr.result;
            var filesize = result.length;
            var delimiters = ['\r\n', ',', '\t', ' '];
            document.getElementById('file_name').innerText = "Supported formats: csv, tsv, txt.";
            for (let i in delimiters){
              var count = result.split(delimiters[i]).length;
              if (count != filesize && count > 1){
                words = result.split(delimiters[i]);
                document.getElementById('file_name').innerText = document.getElementById('file').files[0].name;
              }
            }
          };
          fr.readAsText(this.files[0]);
        });

      // Preprocess the wordlist, train the model, and save it on completion.
      document.getElementById('train')
        .addEventListener('click', async () => {
          if (words.length <= 0){
            alert("No dataset");
            return;
          }
          document.getElementById("status").style.display = "block";
          document.getElementById("train").style.display = "none";
          try{
            let filtered_words = preprocessing_stage_1(words, max_len);
            let int_words = preprocessing_stage_2(filtered_words, max_len);
            let train_features = preprocessing_stage_3(int_words, max_len, sample_len);
            let train_labels = preprocessing_stage_4(int_words, max_len, sample_len);
            train_features = preprocessing_stage_5(train_features, max_len, ALPHA_LEN);
            train_labels = preprocessing_stage_5(train_labels, max_len, ALPHA_LEN);
            model = create_model(max_len, ALPHA_LEN);
            await trainModel(model, train_features, train_labels);
            await model.save('downloads://autocorrect_model');
            // memory management
            train_features.dispose();
            train_labels.dispose();
          }catch (err){
            alert("Not enough GPU memory. Please reduce your dataset size.");
          }
          document.getElementById("status").style.display = "none";
          document.getElementById("train").style.display = "block";
        });

      // Predict the word completion on every keystroke.
      document.getElementById('pred_features')
        .addEventListener('keyup', () => {
          let pattern = new RegExp("^[a-z]{1," + max_len + "}$");
          let pred_features = [document.getElementById('pred_features').value];
          if (pred_features[0].length < sample_len + 1 || !pattern.test(pred_features[0])){
            document.getElementById('pred_labels').value = "";
            return;
          }
          pred_features = preprocessing_stage_2(pred_features, max_len);
          pred_features = preprocessing_stage_5(pred_features, max_len, ALPHA_LEN);
          let pred_tensor = model.predict(pred_features);
          let pred_labels = postprocessing_stage_1(pred_tensor);
          pred_labels = postprocessing_stage_2(pred_labels, max_len)[0];
          document.getElementById('pred_labels').value = pred_labels.join("");
          // memory management
          pred_features.dispose();
          pred_tensor.dispose();
        });

      // Initialise the form controls from the defaults above.
      document.getElementById("max_len").value = max_len;
      document.getElementById("epochs").value = epochs;
      document.getElementById("batch_size").value = batch_size;
      document.getElementById("pred_features").maxLength = max_len;
    }

    function showVizer(){
      const visorInstance = tfvis.visor();
      if (!visorInstance.isOpen()) {
        visorInstance.toggle();
      }
    }

    function preprocessing_stage_1(words, max_len){
      // function to filter the wordlist
      // string [] = words
      // int = max_len
      status = "Preprocessing Data 1";
      console.log(status);
      let filtered_words = [];
      var pattern = new RegExp("^[a-z]{1," + max_len + "}$");
      for (let i in words){
        if (pattern.test(words[i])) filtered_words.push(words[i]);
      }
      return filtered_words;
    }

    function preprocessing_stage_2(words, max_len){
      // function to convert the wordlist to int
      // string [] = words
      // int = max_len
      status = "Preprocessing Data 2";
      console.log(status);
      let int_words = [];
      for (let i in words){
        int_words.push(word_to_int(words[i], max_len));
      }
      return int_words;
    }
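
    // Illustrative example for the sliding window below: with sample_len = 1
    // and max_len = 5, the encoded word "cats" = [3,1,20,19,0] yields one
    // feature row per prefix length 2..5:
    // [3,1,0,0,0], [3,1,20,0,0], [3,1,20,19,0], [3,1,20,19,0].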
    function preprocessing_stage_3(words, max_len, sample_len){
      // function to perform sliding window on wordlist
      // int [] = words
      // int = max_len, sample_len
      status = "Preprocessing Data 3";
      console.log(status);
      let input_data = [];
      for (let x in words){
        for (let y = sample_len + 1; y < max_len + 1; y++){
          input_data.push(words[x].slice(0, y).concat(Array(max_len - y).fill(0)));
        }
      }
      return input_data;
    }
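
    // The matching label for every prefix above is the complete word, so the
    // model learns to map a partial spelling to the full spelling.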
    function preprocessing_stage_4(words, max_len, sample_len){
      // function to ensure that training data size y == x
      // int [] = words
      // int = max_len, sample_len
      status = "Preprocessing Data 4";
      console.log(status);
      let output_data = [];
      for (let x in words){
        for (let y = sample_len + 1; y < max_len + 1; y++){
          output_data.push(words[x]);
        }
      }
      return output_data;
    }

    function preprocessing_stage_5(words, max_len, alpha_len){
      // function to convert int to onehot encoding
      // int [] = words
      // int = max_len, alpha_len
      status = "Preprocessing Data 5";
      console.log(status);
      return tf.oneHot(tf.tensor2d(words, [words.length, max_len], 'int32'), alpha_len);
    }

    function postprocessing_stage_1(words){
      // function to decode onehot encoding
      return words.argMax(-1).arraySync();
    }

    function postprocessing_stage_2(words, max_len){
      // function to convert int to words
      let results = [];
      for (let i in words){
        results.push(int_to_word(words[i], max_len));
      }
      return results;
    }
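
    // Character encoding used by the two helpers below: 0 is padding,
    // 'a' -> 1 .. 'z' -> 26 (charCode - 96).
    // E.g. "cab" with max_len = 5 -> [3, 1, 2, 0, 0].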
    function word_to_int(word, max_len){
      // char [] = word
      // int = max_len
      let encode = [];
      for (let i = 0; i < max_len; i++) {
        if (i < word.length){
          encode.push(word.charCodeAt(i) - 96);
        }else{
          encode.push(0);
        }
      }
      return encode;
    }

    function int_to_word(word, max_len){
      // int [] = word
      // int = max_len
      let decode = [];
      for (let i = 0; i < max_len; i++) {
        if (word[i] == 0){
          decode.push("");
        }else{
          decode.push(String.fromCharCode(word[i] + 96));
        }
      }
      return decode;
    }
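
    // Model: a single LSTM with returnSequences:true feeds a TimeDistributed
    // dense softmax, so every character position gets its own probability
    // distribution over the 26 letters.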
    function create_model(max_len, alpha_len){
      var model = tf.sequential();
      model.add(tf.layers.lstm({
        units: alpha_len * 2,
        inputShape: [max_len, alpha_len],
        dropout: 0.2,
        recurrentDropout: 0.2,
        useBias: true,
        returnSequences: true,
        activation: "relu"
      }));
      model.add(tf.layers.timeDistributed({
        layer: tf.layers.dense({
          units: alpha_len,
          activation: "softmax"
        })
      }));
      model.summary();
      return model;
    }
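
    // categoricalCrossentropy matches the one-hot labels produced in
    // preprocessing_stage_5; mse is tracked only as a secondary metric.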
    async function trainModel(model, train_features, train_labels) {
      status = "Training Model";
      console.log(status);
      // Prepare the model for training.
      model.compile({
        optimizer: tf.train.adam(),
        loss: 'categoricalCrossentropy',
        metrics: ['mse']
      });
      await model.fit(train_features, train_labels, {
        epochs,
        batchSize: batch_size,  // tf.Model.fit expects camelCase batchSize
        shuffle: true,
        callbacks: tfvis.show.fitCallbacks(
          { name: 'Training' },
          ['loss', 'mse'],
          { height: 200, callbacks: ['onEpochEnd'] }
        )
      });
    }

    setup();
  </script>
</body>

</html>