Skip to content

Instantly share code, notes, and snippets.

@ohyicong
Last active October 13, 2024 19:26
Show Gist options
  • Save ohyicong/b1e9dab5eec6371b404dbe603ac4685d to your computer and use it in GitHub Desktop.
Save ohyicong/b1e9dab5eec6371b404dbe603ac4685d to your computer and use it in GitHub Desktop.
LSTM Auto-Complete Model

LSTM auto complete model

Train your own auto-complete model on the web with Tensorflow.js

OS support

  1. Chrome Web browser

Usage

  1. Download predicts.html
  2. Upload dataset.txt
  3. Train
  4. Test
awesome,test,grove,one,two,apple
<!DOCTYPE html>
<html>
<head>
<title>Auto-complete</title>
<!-- Import TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"></script>
<!-- Import tfjs-vis -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tfjs-vis.umd.min.js"></script>
<!-- Import visual frameworks -->
<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css" integrity="sha384-JcKb8q3iqJ61gNV9KGb8thSsNjpSL0n8PARn9HuZOnIxN0hoP+VmmDGMN5t9UJ0Z" crossorigin="anonymous">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.7.0/css/all.css" integrity="sha384-lZN37f5QGtY3VHgisS14W3ExzMWZxybE1SJSEsQp9S+oqd12jhcu+A56Ebc1zFSJ" crossorigin="anonymous">
</head>
<body class="">
<div class="container p-0">
<div class="container p-0" >
<header >
<nav class="navbar navbar-dark bg-dark rounded">
<a class="navbar-brand">
<i class="fas fa-comment"></i>
<strong>Auto-complete</strong>
</a>
</nav>
</header>
</div>
<div class="container mt-3">
<div class="row">
<div class="col">
<label class="form-label">Input Text</label>
<input class="form-control" type="input" id="pred_features" pattern="^[a-z]*$">
</div>
<div class="col">
<label class="form-label">Predicted Text</label>
<input class="form-control" type="input" id="pred_labels" disabled>
</div>
</div>
<div>
<div class="card mt-3 bg-light">
<h5 class="card-header">Training Parameters</h5>
<div class="card-body">
<div class="row">
<div class="col-2">
<label>Max word length</label>
</div>
<div class="col-10">
<input type="number" id="max_len" onchange="max_len=parseInt(this.value);document.getElementById('pred_features').maxLength=parseInt(this.value)" min=1>
</div>
</div>
<div class="row mt-3">
<div class="col-2">
<label>Epochs</label>
</div>
<div class="col-10">
<input type="number" id="epochs" onchange="epochs=parseInt(this.value)" min=1>
</div>
</div>
<div class="row mt-3">
<div class="col-2">
<label>Batch Size</label>
</div>
<div class="col-10">
<input type="number" id="batch_size" onchange="batch_size=parseInt(this.value)" min=1>
</div>
</div>
<div class="row mt-3">
<div class="col-2">
<button class="btn btn-secondary btn-sm" style="cursor:pointer;" type="button" onclick="document.getElementById('file').click()">
<i class="fa fa-upload" aria-hidden="true"></i>
Dataset
</button>
<input style="display:none;"type="file" id="file" >
</div>
<div class="col-10">
<label id="file_name">Supported format: csv, tsv, txt.</label>
</div>
</div>
</div>
</div>
<input class="btn btn-secondary btn-block mt-3" type="button" id="train" value="Train">
<div class="d-flex justify-content-center mt-3" onclick="showVizer()">
<span class="spinner-grow spinner-grow" style="display:none" id="status" aria-hidden="true"></span>
</div>
<div class="float-right">
<a href="https://www.youtube.com/user/chirpieful">
<i class="fab fa-youtube text-danger" aria-hidden="true"></i>
</a>
<a href="https://ohyicong.medium.com/">
<i class="fab fa-medium text-dark" aria-hidden="true"></i>
</a>
</div>
</div>
</body>
<!-- main script file -->
<script>
// Model/encoding constants and mutable training hyper-parameters.
// These are deliberately `var` globals: the inline onchange handlers in the
// HTML above (e.g. "epochs=parseInt(this.value)") reassign them directly.
const ALPHA_LEN = 26; // size of the a-z alphabet (one-hot depth)
var sample_len = 1; // minimum prefix length used when building training windows
var batch_size = 32; // training batch size (editable via the UI)
var epochs = 250; // number of training epochs (editable via the UI)
var max_len = 10; // maximum word length (editable via the UI)
var words = []; // word list loaded from the uploaded dataset file
// NOTE(review): if create_model is async, this assignment stores a Promise,
// not a model — confirm `model` is reassigned (with await, as in the train
// handler) before predict() is first used.
var model = create_model(max_len,ALPHA_LEN);
var status = ""; // human-readable progress label, also echoed to the console
function setup(){
// Wire up all UI event handlers and seed the form fields with the current
// hyper-parameter defaults. Must run after the DOM above is parsed.

// Dataset picker: read the chosen file as text and split it into the global
// `words` using a delimiter that actually partitions the file.
document.getElementById('file')
.addEventListener('change', function() {
const fr = new FileReader();
fr.onload = function(){
const result = fr.result;
const filesize = result.length;
// Candidate delimiters, tried in order. NOTE(review): a later delimiter
// that also yields >1 parts overwrites an earlier split — confirm this
// precedence (space wins over comma, etc.) is intended.
const delimiters = ['\r\n', ',', '\t', ' '];
document.getElementById('file_name').innerText = "Supported format: csv, tsv, txt.";
for (const delim of delimiters){
const parts = result.split(delim);
// `parts.length === filesize` means the delimiter matched every
// character (e.g. splitting by '' semantics) — treat as no match.
if(parts.length !== filesize && parts.length > 1){
words = parts;
document.getElementById('file_name').innerText = document.getElementById('file').files[0].name;
}
}
};
fr.readAsText(this.files[0]);
});

// Train button: run the preprocessing pipeline, rebuild and fit the model,
// then trigger a browser download of the trained weights.
document.getElementById('train')
.addEventListener('click', async () => {
if(words.length <= 0){
alert("No dataset");
return;
}
document.getElementById("status").style.display = "block";
document.getElementById("train").style.display = "none";
try{
const filtered_words = preprocessing_stage_1(words, max_len);
const int_words = preprocessing_stage_2(filtered_words, max_len);
const feature_windows = preprocessing_stage_3(int_words, max_len, sample_len);
const label_windows = preprocessing_stage_4(int_words, max_len, sample_len);
const train_features = preprocessing_stage_5(feature_windows, max_len, ALPHA_LEN);
const train_labels = preprocessing_stage_5(label_windows, max_len, ALPHA_LEN);
model = await create_model(max_len, ALPHA_LEN);
await trainModel(model, train_features, train_labels);
await model.save('downloads://autocorrect_model');
// Memory management: free the one-hot tensors held on the GPU.
train_features.dispose();
train_labels.dispose();
}catch (err){
// Log the real cause; the alert only covers the common OOM case.
console.error(err);
alert("Not enough GPU space. Please reduce your dataset size.");
}
document.getElementById("status").style.display = "none";
document.getElementById("train").style.display = "block";
});

// Live prediction: on every keystroke, validate the input, encode it, and
// show the model's per-position character predictions.
document.getElementById('pred_features')
.addEventListener('keyup', () => {
const typed = document.getElementById('pred_features').value;
const pattern = new RegExp("^[a-z]{1," + max_len + "}$");
// Require at least sample_len+1 lowercase letters before predicting.
if(typed.length < sample_len + 1 || !pattern.test(typed)){
document.getElementById('pred_labels').value = "";
return;
}
let pred_features = preprocessing_stage_2([typed], max_len);
pred_features = preprocessing_stage_5(pred_features, max_len, ALPHA_LEN);
let pred_labels = model.predict(pred_features);
pred_labels = postprocessing_stage_1(pred_labels);
pred_labels = postprocessing_stage_2(pred_labels, max_len)[0];
document.getElementById('pred_labels').value = pred_labels.join("");
});

// Seed the form with the current defaults.
document.getElementById("max_len").value = max_len;
document.getElementById("epochs").value = epochs;
document.getElementById("batch_size").value = batch_size;
document.getElementById("pred_features").maxLength = document.getElementById("max_len").value;
}
function showVizer(){
// Open the tfjs-vis visor panel if it is currently closed; no-op otherwise.
const visor = tfvis.visor();
if (visor.isOpen()) return;
visor.toggle();
}
function preprocessing_stage_1(words,max_len){
// Stage 1: filter the word list, keeping only words made of 1..max_len
// lowercase ASCII letters.
//   words   : string[] — raw words from the dataset
//   max_len : int      — maximum allowed word length
// Returns the filtered string[].
status = "Preprocessing Data 1";
console.log(status);
const pattern = new RegExp("^[a-z]{1," + max_len + "}$");
return words.filter((word) => pattern.test(word));
}
function preprocessing_stage_2(words,max_len){
// Stage 2: convert each word to a fixed-length int encoding via word_to_int
// ('a'→1 … 'z'→26, zero-padded to max_len).
//   words   : string[] — filtered words
//   max_len : int      — encoding length
// Returns int[][].
status = "Preprocessing Data 2";
console.log(status);
return words.map((word) => word_to_int(word, max_len));
}
function preprocessing_stage_3(words,max_len,sample_len){
// Stage 3: build training feature windows. For each encoded word, emit one
// window per prefix length y in (sample_len, max_len]: the first y ints,
// right-padded with zeros back to max_len.
//   words      : int[][] — encoded words
//   max_len    : int     — window length
//   sample_len : int     — shortest prefix minus one
// Returns int[][] with (max_len - sample_len) windows per word.
status = "Preprocessing Data 3";
console.log(status);
const input_data = [];
for (const word of words){
for (let y = sample_len + 1; y < max_len + 1; y++){
input_data.push(word.slice(0, y).concat(Array(max_len - y).fill(0)));
}
}
return input_data;
}
function preprocessing_stage_4(words,max_len,sample_len){
// Stage 4: build training labels matched 1:1 with stage 3's feature windows.
// Each word is repeated once per feature window derived from it, so the
// label for every prefix window is the full encoded word.
//   words      : int[][] — encoded words
//   max_len    : int
//   sample_len : int
// Returns int[][] with (max_len - sample_len) copies per word.
status = "Preprocessing Data 4";
console.log(status);
const copies = max_len - sample_len; // same count as stage 3's y-loop
const output_data = [];
for (const word of words){
for (let i = 0; i < copies; i++){
output_data.push(word);
}
}
return output_data;
}
function preprocessing_stage_5(words,max_len,alpha_len){
// Stage 5: one-hot encode the int windows into a [N, max_len, alpha_len]
// tensor for the LSTM.
//   words     : int[][] — encoded windows
//   max_len   : int     — sequence length
//   alpha_len : int     — one-hot depth (alphabet size)
// Returns a tf.Tensor (caller is responsible for dispose()).
status = "Preprocessing Data 5";
console.log(status);
// BUG FIX: the original passed `dtype='int32'` — a Python-style keyword
// argument that in JS assigns an implicit global named `dtype` and merely
// passes the string positionally. Pass the dtype positionally and cleanly.
return tf.oneHot(tf.tensor2d(words, [words.length, max_len], 'int32'), alpha_len);
}
function postprocessing_stage_1(words){
// Decode one-hot predictions: argmax along the last axis, then materialize
// the result tensor as a nested plain-JS array.
const indices = words.argMax(-1);
return indices.arraySync();
}
function postprocessing_stage_2(words,max_len){
// Convert each int-encoded prediction back to an array of characters
// (1→'a' … 26→'z', 0→"") via int_to_word.
//   words   : int[][] — argmax-decoded predictions
//   max_len : int     — sequence length
// Returns string[][] (one char array per word).
return words.map((word) => int_to_word(word, max_len));
}
function word_to_int (word,max_len){
// Encode a lowercase word as a fixed-length int array: 'a'→1 … 'z'→26,
// with 0 padding out the tail up to max_len.
//   word    : string
//   max_len : int
return Array.from({ length: max_len }, (_, i) =>
i < word.length ? word.charCodeAt(i) - 96 : 0
);
}
function int_to_word (word,max_len){
// Decode a fixed-length int array back to characters: 1→'a' … 26→'z',
// and 0 (the padding value) decodes to the empty string.
//   word    : int[]
//   max_len : int
const chars = [];
for (let i = 0; i < max_len; i++) {
const code = word[i];
// keep loose == so only an exact 0 pad maps to "" (undefined falls through)
chars.push(code == 0 ? "" : String.fromCharCode(code + 96));
}
return chars;
}
function create_model(max_len,alpha_len){
// Build a character-level LSTM: one-hot input [max_len, alpha_len] → LSTM
// (returnSequences) → per-timestep dense softmax over the alphabet.
//   max_len   : int — sequence length
//   alpha_len : int — alphabet size (one-hot depth / output units)
// Returns the uncompiled tf.Sequential model.
//
// BUG FIX: this was `async` with `await model.add(...)`, but layers.add()
// is synchronous. The async wrapper made the top-level
// `var model = create_model(...)` hold a Promise instead of a model,
// so predict() before the first training run could never work. Making it
// synchronous is backward compatible: the train handler's
// `await create_model(...)` still resolves to the model.
const model = tf.sequential();
model.add(tf.layers.lstm({
units: alpha_len * 2,
inputShape: [max_len, alpha_len],
dropout: 0.2,
recurrentDropout: 0.2,
useBias: true,
returnSequences: true,
activation: "relu"
}));
model.add(tf.layers.timeDistributed({
layer: tf.layers.dense({
units: alpha_len,
// NOTE(review): `dropout` is not a tf.layers.dense option and appears
// to be ignored — confirm whether a separate dropout layer was intended.
dropout: 0.2,
activation: "softmax"
})
}));
model.summary();
return model;
}
async function trainModel(model, train_features, train_labels) {
// Compile the model and fit it on the one-hot feature/label tensors,
// streaming loss/mse curves to the tfjs-vis visor after every epoch.
//   model          : tf.Sequential (from create_model)
//   train_features : tf.Tensor [N, max_len, alpha_len]
//   train_labels   : tf.Tensor [N, max_len, alpha_len]
status = "Training Model";
console.log(status);
model.compile({
optimizer: tf.train.adam(),
loss: 'categoricalCrossentropy',
metrics: ['mse']
});
// BUG FIX: tf.Model.fit expects `batchSize` (camelCase). The original
// `batch_size` shorthand key was silently ignored, so the batch size the
// user set in the UI never took effect (tfjs fell back to its default).
await model.fit(train_features, train_labels, {
epochs,
batchSize: batch_size,
shuffle: true,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training' },
['loss', 'mse'],
{ height: 200, callbacks: ['onEpochEnd'] }
)
});
}
setup();
</script>
</html>
@momen-hafez
Copy link

Hi, this is great work. I want to ask whether it's possible to extend it to predict sentences, not only words, and also to rank the predictions according to the dataset.
Regards.

@LakshmanKishore
Copy link

LakshmanKishore commented Jul 31, 2024

Perfect use of LSTM on the Web, I loved it!!
Few suggestions for this:

  1. Could have trained an LSTM on a big corpus of dictionary, and gave a big input box to let user write the sentence and whenever user types the word it should show auto completion of the words.
  2. Have a big input box, and as and when user types the words if there are any new words it should store and then ask user to click on the train button to train with the new words.
  3. Could make a chrome extension with the default LSTM which was trained on the dictionary and use the same to access the text box and then perform autocomplete as and when user types in any text box in any website. (which I might work on because of this project, so thanks for that)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment