This is a simple sample that forecasts a time series of air passenger data
Created
May 20, 2019 05:44
-
-
Save gangtao/7d8c5e62e7bf2605943489e31e974953 to your computer and use it in GitHub Desktop.
TensorFlow.js Time Series Forecast
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!--
  Page layout: a fluid container with three stacked rows.
  #dataView is the container the script's drawChart() renders the G2 chart into.
  #modelView and #predictionView are placeholders; the accompanying script
  shown with this markup does not reference them.
-->
<div class="container-fluid" id="main"> | |
<div class="row"> | |
<div class="col-6"> | |
<div id="dataView"></div> | |
</div> | |
</div> | |
<div class="row"> | |
<div id="modelView"></div> | |
</div> | |
<div class="row"> | |
<div id="predictionView"></div> | |
</div> | |
</div> |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// ---- Hyper-parameters of the sliding-window LSTM forecaster ----

// Number of consecutive values in one sub-window (LSTM feature width).
const STEP_SIZE = 3;
// Number of sub-windows per sample (LSTM time steps).
const STEP_NUM = 24;
// Shift, in data points, between the starts of consecutive sub-windows.
const STEP_OFFSET = 1;
// Number of future values the model outputs per sample.
const TARGET_SIZE = 12;
// Units of the first LSTM layer.
const LSTM_UNITS = 30;

// Count of raw data points spanned by all sub-windows of one sample.
const X_LEN = (STEP_NUM - 1) * STEP_OFFSET + STEP_SIZE;
// Count of raw data points used as the target of one sample.
const Y_LEN = TARGET_SIZE;

// Options forwarded to model.fit().
const config = {
  epochs: 30,
  batchSize: 4
};
/**
 * Loads a CSV resource through d3.
 *
 * @param {string} path - URL or path handed to d3.csv.
 * @returns {Promise<*>} Resolves with whatever d3.csv resolves with.
 */
async function loadData(path) {
  const rows = await d3.csv(path);
  return rows;
}
/**
 * Clears the #dataView element and renders a G2 line chart of `data` into it.
 *
 * The line is plotted as time*value and coloured by the `isPrediction` field.
 *
 * @param {Array<Object>} data - Rows passed to chart.source(); the chart reads
 *   the `time`, `value` and `isPrediction` fields.
 */
function drawChart(data) {
  // Remove any previously rendered chart from the container.
  $('#dataView').empty();

  const lineChart = new G2.Chart({
    container: "dataView",
    forceFit: true,
    height: 300
  });
  lineChart.source(data);

  // Scale definitions, applied in this order.
  const scaleDefinitions = [
    ["value", { min: 0, max: 800 }],
    ["predictValue", { min: 0, max: 800 }],
    ["time", { tickCount: 10, type: "time" }]
  ];
  for (const [field, options] of scaleDefinitions) {
    lineChart.scale(field, options);
  }

  const geometry = lineChart.line();
  geometry.position("time*value").color("isPrediction");
  lineChart.render();
}
/**
 * Cuts `data` into `stepNum` sub-windows of `stepSize` values each.
 *
 * Sub-window k starts at index k * stepOffset. A sub-window that runs past the
 * end of `data` is simply shorter (Array.prototype.slice semantics).
 *
 * @param {Array<{value: number}>} data - Points to window.
 * @param {number} stepSize - Values per sub-window.
 * @param {number} stepNum - Number of sub-windows to produce.
 * @param {number} stepOffset - Index shift between consecutive sub-windows.
 * @returns {number[][]} One array of `value`s per sub-window.
 */
function buildX(data, stepSize, stepNum, stepOffset) {
  return Array.from({ length: stepNum }, (_, step) => {
    const from = step * stepOffset;
    return data.slice(from, from + stepSize).map(({ value }) => value);
  });
}
/**
 * Builds supervised training samples from a series with a sliding window.
 *
 * Each sample pairs `xLength = stepSize + stepOffset * (stepNum - 1)`
 * consecutive points (reshaped by buildX into stepNum sub-windows) with the
 * `targetSize` points that IMMEDIATELY follow them. The target must start right
 * after the input window, because at inference time the first predicted value
 * is treated as the step directly after the last input point.
 *
 * @param {Array<{value: number}>} data - The series, oldest point first.
 * @param {number} stepSize - Values per sub-window.
 * @param {number} stepNum - Sub-windows per sample.
 * @param {number} stepOffset - Index shift between consecutive sub-windows.
 * @param {number} targetSize - Number of target values per sample.
 * @returns {{x: number[][][], y: number[][]}} Parallel arrays of inputs and
 *   targets; both are empty when `data` is too short for a single sample.
 */
function makeTrainData(data, stepSize, stepNum, stepOffset, targetSize) {
  const trainData = { x: [], y: [] };
  const xLength = stepSize + stepOffset * (stepNum - 1);
  const yLength = targetSize;

  // Last start index for which both the input window and its target fit.
  const lastStart = data.length - xLength - yLength;

  for (let i = 0; i <= lastStart; i += 1) {
    const targetStart = i + xLength;
    const inputs = data.slice(i, targetStart);
    const targets = data.slice(targetStart, targetStart + yLength);

    trainData.x.push(buildX(inputs, stepSize, stepNum, stepOffset));
    trainData.y.push(targets.map(item => item.value));
  }
  return trainData;
}
/**
 * Creates and compiles the forecasting network.
 *
 * Layer stack, in order:
 *   1. LSTM with LSTM_UNITS units, input shape [STEP_NUM, STEP_SIZE],
 *      returning the full sequence;
 *   2. LSTM with TARGET_SIZE units, returning only its last output;
 *   3. Dense with TARGET_SIZE units;
 *   4. tanh activation.
 * Compiled with RMSProp (learning rate 0.005) and mean-squared-error loss.
 *
 * @returns {*} The compiled tf.sequential() model.
 */
function buildModel() {
  const net = tf.sequential();

  // First recurrent layer: consumes the windowed input.
  net.add(
    tf.layers.lstm({
      units: LSTM_UNITS,
      inputShape: [STEP_NUM, STEP_SIZE],
      returnSequences: true
    })
  );

  // A dropout (rate 0.2) followed by an extra LSTM_UNITS-wide LSTM used to sit
  // here in the original source; it was disabled there and is not built.

  // Second recurrent layer: collapses the sequence to a single vector.
  net.add(
    tf.layers.lstm({
      units: TARGET_SIZE,
      returnSequences: false
    })
  );

  // Output head.
  net.add(tf.layers.dense({ units: TARGET_SIZE }));
  net.add(tf.layers.activation({ activation: "tanh" }));

  // Compile.
  const optimizer = tf.train.rmsprop(0.005);
  net.compile({ optimizer, loss: tf.losses.meanSquaredError });

  return net;
}
/**
 * Runs the model on a single sample and returns the raw output values.
 *
 * The sample is wrapped in a batch of one (`tf.tensor([input])`). Both the
 * input tensor and the model's output tensor are disposed before returning,
 * because TensorFlow.js tensors are not garbage-collected automatically.
 *
 * @param {*} model - Object exposing `predict(tensor)` that returns a tensor.
 * @param {Array} input - One sample (e.g. the result of buildX).
 * @returns {Promise<*>} The typed array resolved by the output tensor's `data()`.
 */
async function predict(model, input) {
  const inputTensor = tf.tensor([input]);
  let outputTensor;
  try {
    outputTensor = model.predict(inputTensor);
    // Await here so the finally block runs only after the values are downloaded.
    return await outputTensor.data();
  } finally {
    inputTensor.dispose();
    if (outputTensor) {
      outputTensor.dispose();
    }
  }
}
/**
 * Creates tfjs-vis fit callbacks on the "Training" tab and hands them to
 * `train(model, data, callbacks)`.
 *
 * NOTE(review): `model`, `data` and `train` are free identifiers here; none of
 * them is declared in the visible part of this file, and nothing visible calls
 * this function. Unless they exist as globals elsewhere, calling it throws a
 * ReferenceError. Confirm whether this function is still needed.
 *
 * @returns {Promise<*>} Whatever `train` returns/resolves with.
 */
async function watchTraining() {
  const surface = {
    name: "show.fitCallbacks",
    tab: "Training",
    styles: { height: "1000px" }
  };
  const trackedMetrics = ["loss", "val_loss", "acc", "val_acc"];
  const fitCallbacks = tfvis.show.fitCallbacks(surface, trackedMetrics);
  return train(model, data, fitCallbacks);
}
/**
 * Fits `model` on the prepared samples while streaming metrics to tfjs-vis.
 *
 * Uses `config.batchSize` and `config.epochs`, holds out 20% of the samples
 * for validation, and opens the tfjs-vis visor. The input/target tensors
 * created here are disposed once fitting finishes or fails, because
 * TensorFlow.js tensors are not garbage-collected automatically.
 *
 * @param {{x: number[][][], y: number[][]}} data - Samples as produced by
 *   makeTrainData: `x` is converted with tf.tensor3d, `y` with tf.tensor2d.
 * @param {*} model - Compiled model exposing `fit(xs, ys, options)`.
 * @returns {Promise<*>} The value resolved by `model.fit`.
 */
async function trainBatch(data, model) {
  const metrics = ["loss", "val_loss", "acc", "val_acc"];
  const container = {
    name: "show.fitCallbacks",
    tab: "Training",
    styles: {
      height: "1000px"
    }
  };
  const callbacks = tfvis.show.fitCallbacks(container, metrics);
  console.log("training start!");
  tfvis.visor();
  // Save the model
  // const saveResults = await model.save('downloads://air-time-model');
  const xs = tf.tensor3d(data.x);
  let ys;
  try {
    ys = tf.tensor2d(data.y);
    const history = await model.fit(xs, ys, {
      batchSize: config.batchSize,
      epochs: config.epochs,
      validationSplit: 0.2,
      callbacks: callbacks
    });
    console.log("training complete!");
    return history;
  } finally {
    // Release the training tensors whether fitting succeeded or threw.
    xs.dispose();
    if (ys) {
      ys.dispose();
    }
  }
}
// Entry point, run on DOM ready: load the air-passenger CSV, chart it, train
// the model on month-over-month relative changes, forecast the following
// months and redraw the chart with the forecast appended.
$(async function() {
  try {
    const airPassagnerData = await loadData(
      "https://cdn.jsdelivr.net/gh/gangtao/datasets@master/csv/air_passengers.csv"
    );

    // Parse each count once, always as base 10.
    const passengerCounts = airPassagnerData.map(row => parseInt(row.Number, 10));

    const chartData = airPassagnerData.map((row, i) => ({
      time: row.Date,
      value: passengerCounts[i]
    }));
    drawChart(chartData);

    // Normalize data with value change: each point becomes the relative change
    // versus the previous point, so the series has one element fewer.
    const changeData = [];
    for (let i = 1; i < airPassagnerData.length; i += 1) {
      changeData.push({
        date: airPassagnerData[i].Date,
        value: passengerCounts[i] / passengerCounts[i - 1] - 1
      });
    }

    const trainData = makeTrainData(
      changeData,
      STEP_SIZE,
      STEP_NUM,
      STEP_OFFSET,
      TARGET_SIZE
    );
    const model = buildModel();
    model.summary();
    await trainBatch(trainData, model);

    // Forecast from the most recent X_LEN changes.
    const input = changeData.slice(changeData.length - X_LEN);
    const predictInput = buildX(input, STEP_SIZE, STEP_NUM, STEP_OFFSET);
    const prediction = await predict(model, predictInput);

    // Reconstruct predicted values from the predicted relative changes,
    // compounding from the last observed value.
    const base = airPassagnerData[airPassagnerData.length - 1];
    const baseDate = moment(new Date(base.Date));
    const predictionValue = [];
    let val = passengerCounts[passengerCounts.length - 1];
    for (let i = 0; i < prediction.length; i += 1) {
      val += val * prediction[i];
      predictionValue.push({
        // moment#add mutates its receiver; offset a clone from the fixed base
        // so month-end clamping cannot accumulate across iterations.
        time: baseDate.clone().add(i + 1, 'months').format('YYYY-MM-DD'),
        value: val,
        isPrediction: "Yes"
      });
    }
    console.log(predictionValue);

    // Tag the observed points and append the forecast for a combined chart.
    chartData.forEach(item => {
      item.isPrediction = "No";
    });
    const airPassagnerDataWithPrediction = chartData.concat(predictionValue);
    drawChart(airPassagnerDataWithPrediction);
  } catch (err) {
    // The ready handler's promise is not awaited by anyone, so report failures
    // (network, parsing, training) explicitly instead of leaving an unhandled
    // rejection.
    console.error("air passenger forecast failed:", err);
  }
});
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/1.0.4/tf.min.js"></script> | |
<script src="https://cdnjs.cloudflare.com/ajax/libs/d3/5.9.2/d3.min.js"></script> | |
<script src="https://cdnjs.cloudflare.com/ajax/libs/babel-polyfill/7.4.3/polyfill.min.js"></script> | |
<script src="https://gw.alipayobjects.com/os/lib/antv/g2/3.4.10/dist/g2.min.js"></script> | |
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.4.0/jquery.min.js"></script> | |
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script> | |
<script src="https://cdnjs.cloudflare.com/ajax/libs/moment.js/2.24.0/moment.min.js"></script> |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<link href="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/4.1.3/css/bootstrap.min.css" rel="stylesheet" /> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment