Train your own auto-complete model on the web with TensorFlow.js
- Chrome web browser
- Download predicts.html
- Upload dataset.txt
- Train
- Test
Example dataset.txt (a plain list of words separated by commas, tabs, spaces, or newlines):

awesome,test,grove,one,two,apple

predicts.html:
<!DOCTYPE html>
<html>
<head>
  <title>Auto-complete</title>
  <!-- Import TensorFlow.js -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>
  <!-- Import tfjs-vis -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis/dist/tfjs-vis.umd.min.js"></script>
  <!-- Import visual frameworks -->
  <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css" integrity="sha384-JcKb8q3iqJ61gNV9KGb8thSsNjpSL0n8PARn9HuZOnIxN0hoP+VmmDGMN5t9UJ0Z" crossorigin="anonymous">
  <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.7.0/css/all.css" integrity="sha384-lZN37f5QGtY3VHgisS14W3ExzMWZxybE1SJSEsQp9S+oqd12jhcu+A56Ebc1zFSJ" crossorigin="anonymous">
</head>
<body class="">
  <div class="container p-0">
    <div class="container p-0">
      <header>
        <nav class="navbar navbar-dark bg-dark rounded">
          <a class="navbar-brand">
            <i class="fas fa-comment"></i>
            <strong>Auto-complete</strong>
          </a>
        </nav>
      </header>
    </div>
    <div class="container mt-3">
      <div class="row">
        <div class="col">
          <label class="form-label">Input Text</label>
          <input class="form-control" type="text" id="pred_features" pattern="^[a-z]*$">
        </div>
        <div class="col">
          <label class="form-label">Predicted Text</label>
          <input class="form-control" type="text" id="pred_labels" disabled>
        </div>
      </div>
      <div>
        <div class="card mt-3 bg-light">
          <h5 class="card-header">Training Parameters</h5>
          <div class="card-body">
            <div class="row">
              <div class="col-2">
                <label>Max word length</label>
              </div>
              <div class="col-10">
                <input type="number" id="max_len" onchange="max_len=parseInt(this.value);document.getElementById('pred_features').maxLength=parseInt(this.value)" min=1>
              </div>
            </div>
            <div class="row mt-3">
              <div class="col-2">
                <label>Epochs</label>
              </div>
              <div class="col-10">
                <input type="number" id="epochs" onchange="epochs=parseInt(this.value)" min=1>
              </div>
            </div>
            <div class="row mt-3">
              <div class="col-2">
                <label>Batch Size</label>
              </div>
              <div class="col-10">
                <input type="number" id="batch_size" onchange="batch_size=parseInt(this.value)" min=1>
              </div>
            </div>
            <div class="row mt-3">
              <div class="col-2">
                <button class="btn btn-secondary btn-sm" style="cursor:pointer;" type="button" onclick="document.getElementById('file').click()">
                  <i class="fa fa-upload" aria-hidden="true"></i>
                  Dataset
                </button>
                <input style="display:none;" type="file" id="file">
              </div>
              <div class="col-10">
                <label id="file_name">Supported format: csv, tsv, txt.</label>
              </div>
            </div>
          </div>
        </div>
        <input class="btn btn-secondary btn-block mt-3" type="button" id="train" value="Train">
        <div class="d-flex justify-content-center mt-3" onclick="showVizer()">
          <span class="spinner-grow" style="display:none" id="status" aria-hidden="true"></span>
        </div>
        <div class="float-right">
          <a href="https://www.youtube.com/user/chirpieful">
            <i class="fab fa-youtube text-danger" aria-hidden="true"></i>
          </a>
          <a href="https://ohyicong.medium.com/">
            <i class="fab fa-medium text-dark" aria-hidden="true"></i>
          </a>
        </div>
      </div>
    </div>
  </div>
</body>
<!-- main script file -->
<script>
  const ALPHA_LEN = 26;
  var sample_len = 1;
  var batch_size = 32;
  var epochs = 250;
  var max_len = 10;
  var words = [];
  var model = create_model(max_len,ALPHA_LEN);
  var status = "";
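  // Notes added for clarity: letters are encoded as a=1 ... z=26 and 0 is the padding
  // value, so ALPHA_LEN is the one-hot depth used throughout. sample_len sets the
  // shortest prefix that is ever used: both the training windows and live predictions
  // start at sample_len+1 letters. One caveat of this 1-based encoding: with a depth
  // of 26, the code 26 ('z') falls outside the 0-25 one-hot range and ends up encoded
  // as all zeros, like padding; raising ALPHA_LEN to 27 would cover it.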
  function setup(){
    document.getElementById('file')
      .addEventListener('change', function() {
        var fr = new FileReader();
        fr.onload = function(){
          const result = fr.result;
          const filesize = result.length;
          // try newline (Unix or Windows), comma, tab and space as delimiters
          const delimiters = [/\r?\n/, ',', '\t', ' '];
          document.getElementById('file_name').innerText = "Supported format: csv, tsv, txt.";
          for (let i in delimiters){
            const length = result.split(delimiters[i]).length;
            if(length != filesize && length > 1){
              words = result.split(delimiters[i]);
              document.getElementById('file_name').innerText = document.getElementById('file').files[0].name;
            }
          }
        }
        fr.readAsText(this.files[0]);
      })
    document.getElementById('train')
      .addEventListener('click', async ()=>{
        if(words.length <= 0){
          alert("No dataset");
          return;
        }
        document.getElementById("status").style.display = "block";
        document.getElementById("train").style.display = "none";
        try{
          let filtered_words = preprocessing_stage_1(words, max_len);
          let int_words = preprocessing_stage_2(filtered_words, max_len);
          let train_features = preprocessing_stage_3(int_words, max_len, sample_len);
          let train_labels = preprocessing_stage_4(int_words, max_len, sample_len);
          train_features = preprocessing_stage_5(train_features, max_len, ALPHA_LEN);
          train_labels = preprocessing_stage_5(train_labels, max_len, ALPHA_LEN);
          model = create_model(max_len, ALPHA_LEN);
          await trainModel(model, train_features, train_labels);
          await model.save('downloads://autocorrect_model');
          //memory management
          train_features.dispose();
          train_labels.dispose();
        }catch (err){
          console.error(err);
          alert("Not enough GPU memory. Please reduce your dataset size.");
        }
        document.getElementById("status").style.display = "none";
        document.getElementById("train").style.display = "block";
      })
    document.getElementById('pred_features')
      .addEventListener('keyup', ()=>{
        console.log(document.getElementById('pred_features').value);
        let pattern = new RegExp("^[a-z]{1," + max_len + "}$");
        let pred_features = [];
        pred_features.push(document.getElementById('pred_features').value);
        if(pred_features[0].length < sample_len+1 || !pattern.test(pred_features[0])){
          document.getElementById('pred_labels').value = "";
          return;
        }
        pred_features = preprocessing_stage_2(pred_features, max_len);
        const feature_tensor = preprocessing_stage_5(pred_features, max_len, ALPHA_LEN);
        const label_tensor = model.predict(feature_tensor);
        let pred_labels = postprocessing_stage_1(label_tensor);
        pred_labels = postprocessing_stage_2(pred_labels, max_len)[0];
        //memory management
        feature_tensor.dispose();
        label_tensor.dispose();
        document.getElementById('pred_labels').value = pred_labels.join("");
      })
    document.getElementById("max_len").value = max_len;
    document.getElementById("epochs").value = epochs;
    document.getElementById("batch_size").value = batch_size;
    document.getElementById("pred_features").maxLength = document.getElementById("max_len").value;
  }
  function showVizer(){
    const visorInstance = tfvis.visor();
    if (!visorInstance.isOpen()) {
      visorInstance.toggle();
    }
  }
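  // Overview of the preprocessing pipeline below (added for clarity):
  //   stage 1 keeps only lowercase words of at most max_len letters,
  //   stage 2 encodes each word as an integer array of length max_len (0-padded),
  //   stage 3 builds training features by sliding a window over each word (every
  //           prefix of at least sample_len+1 letters, padded back to max_len),
  //   stage 4 repeats the full word once per prefix so features and labels line up,
  //   stage 5 one-hot encodes either set into a tensor of shape [samples, max_len, alpha_len].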
  function preprocessing_stage_1(words, max_len){
    // function to filter the wordlist
    // string [] = words
    // int = max_len
    status = "Preprocessing Data 1";
    console.log(status);
    let filtered_words = [];
    var pattern = new RegExp("^[a-z]{1," + max_len + "}$");
    for (let i in words){
      var is_valid = pattern.test(words[i]);
      if (is_valid) filtered_words.push(words[i]);
    }
    return filtered_words;
  }
  function preprocessing_stage_2(words, max_len){
    // function to convert the wordlist to int
    // string [] = words
    // int = max_len
    status = "Preprocessing Data 2";
    console.log(status);
    let int_words = [];
    for (let i in words){
      int_words.push(word_to_int(words[i], max_len));
    }
    return int_words;
  }
  function preprocessing_stage_3(words, max_len, sample_len){
    // function to perform sliding window on wordlist
    // int [] = words
    // int = max_len, sample_len
    status = "Preprocessing Data 3";
    console.log(status);
    let input_data = [];
    for (let x in words){
      for (let y = sample_len+1; y < max_len+1; y++){
        input_data.push(words[x].slice(0, y).concat(Array(max_len-y).fill(0)));
      }
    }
    return input_data;
  }
  function preprocessing_stage_4(words, max_len, sample_len){
    // function to ensure that training data size y == x
    // int [] = words
    // int = max_len, sample_len
    status = "Preprocessing Data 4";
    console.log(status);
    let output_data = [];
    for (let x in words){
      for (let y = sample_len+1; y < max_len+1; y++){
        output_data.push(words[x]);
      }
    }
    return output_data;
  }
  function preprocessing_stage_5(words, max_len, alpha_len){
    // function to convert int to onehot encoding
    // int [] = words
    // int = max_len, alpha_len
    status = "Preprocessing Data 5";
    console.log(status);
    return tf.oneHot(tf.tensor2d(words, [words.length, max_len], 'int32'), alpha_len);
  }
  function postprocessing_stage_1(words){
    //function to decode onehot encoding
    return words.argMax(-1).arraySync();
  }
  function postprocessing_stage_2(words, max_len){
    //function to convert int to words
    let results = [];
    for (let i in words){
      results.push(int_to_word(words[i], max_len));
    }
    return results;
  }
  function word_to_int(word, max_len){
    // char [] = word
    // int = max_len
    let encode = [];
    for (let i = 0; i < max_len; i++) {
      if(i < word.length){
        let letter = word.slice(i, i+1);
        encode.push(letter.charCodeAt(0) - 96);
      }else{
        encode.push(0);
      }
    }
    return encode;
  }
  function int_to_word(word, max_len){
    // int [] = word
    // int = max_len
    let decode = [];
    for (let i = 0; i < max_len; i++) {
      if(word[i] == 0){
        decode.push("");
      }else{
        decode.push(String.fromCharCode(word[i] + 96));
      }
    }
    return decode;
  }
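  // Worked example (added for clarity): with max_len = 5, word_to_int("cat", 5)
  // returns [3, 1, 20, 0, 0], and int_to_word([3, 1, 20, 0, 0], 5) gives back
  // ["c", "a", "t", "", ""], which joins to "cat".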
  function create_model(max_len, alpha_len){
    // building the model is synchronous; only training and saving are async
    var model = tf.sequential();
    model.add(tf.layers.lstm({
      units: alpha_len*2,
      inputShape: [max_len, alpha_len],
      dropout: 0.2,
      recurrentDropout: 0.2,
      useBias: true,
      returnSequences: true,
      activation: "relu"
    }));
    model.add(tf.layers.timeDistributed({
      layer: tf.layers.dense({
        units: alpha_len,
        activation: "softmax"
      })
    }));
    model.summary();
    return model;
  }
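  // Shape summary (added for clarity): the input is a one-hot batch of shape
  // [samples, max_len, alpha_len]; because the LSTM returns sequences, the
  // time-distributed softmax outputs [samples, max_len, alpha_len], and taking
  // argMax over the last axis (postprocessing_stage_1) yields one letter code
  // per position.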
  async function trainModel(model, train_features, train_labels) {
    status = "Training Model";
    console.log(status);
    // Prepare the model for training.
    model.compile({
      optimizer: tf.train.adam(),
      loss: 'categoricalCrossentropy',
      metrics: ['mse']
    });
    await model.fit(train_features, train_labels, {
      epochs,
      batchSize: batch_size,
      shuffle: true,
      callbacks: tfvis.show.fitCallbacks(
        { name: 'Training' },
        ['loss', 'mse'],
        { height: 200, callbacks: ['onEpochEnd'] }
      )
    });
    return;
  }
  setup();
</script>
</html>
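Once training finishes, model.save('downloads://autocorrect_model') downloads autocorrect_model.json and autocorrect_model.weights.bin to your computer. As a minimal sketch (not part of predicts.html), a page that also loads tf.min.js could restore that saved model instead of retraining, assuming both downloaded files are hosted next to it:

// hypothetical helper for reloading the model saved by the Train button
async function load_saved_model() {
  // loadLayersModel reads the topology JSON and the weight file it references
  const loaded = await tf.loadLayersModel('autocorrect_model.json');
  loaded.summary();
  return loaded;
}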
Perfect use of an LSTM on the web, I loved it!
A few suggestions for this:
Hi, this project is great. I want to ask if it's possible to develop it further to predict sentences, not only words, and also to rank the predictions according to the dataset.
Regards.