Created
June 26, 2024 18:52
-
-
Save nampdn/1ee2201876b4855313a1ca8aa95f5e8f to your computer and use it in GitHub Desktop.
train-neural-network-visualization.html
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Neural Network Training on Sine Function</title>
  <script src="https://cdnjs.cloudflare.com/ajax/libs/Chart.js/3.7.0/chart.min.js"></script>
  <style>
    body { font-family: Arial, sans-serif; margin: 0; padding: 20px; display: flex; flex-direction: column; align-items: center; }
    #visualization-container { display: flex; justify-content: space-around; width: 100%; margin-bottom: 20px; flex-wrap: wrap; }
    /* Chart.js sizes a responsive chart from its parent element, so each chart
       canvas gets its own fixed-size, relatively positioned wrapper instead of
       stretching to the width of the shared flex container. The border/margin
       live on the wrapper so the chart canvas fills it exactly. */
    #networkCanvas, .chart-container { border: 1px solid #ccc; margin: 10px; }
    .chart-container { position: relative; width: 400px; height: 300px; }
    #controls { margin-top: 20px; }
    button, input { margin: 5px; }
  </style>
</head>
<body>
  <h1>Neural Network Training on Sine Function</h1>
  <div id="visualization-container">
    <!-- Network diagram, drawn manually with the 2D canvas API. -->
    <canvas id="networkCanvas" width="800" height="400"></canvas>
    <!-- Chart.js canvases: target vs. network output, and training error. -->
    <div class="chart-container">
      <canvas id="functionPlot" width="400" height="300"></canvas>
    </div>
    <div class="chart-container">
      <canvas id="errorChart" width="400" height="300"></canvas>
    </div>
  </div>
  <div id="controls">
    <button id="startBtn">Start Training</button>
    <button id="resetBtn">Reset</button>
    <div>
      Learning Rate: <span id="learningRateValue">0.01</span>
      <input type="range" id="learningRateSlider" min="0.001" max="0.1" step="0.001" value="0.01">
    </div>
  </div>
<script>
// DOM handles: the canvas the network diagram is painted on, plus the
// training controls declared in the page markup.
const networkCanvas = document.getElementById('networkCanvas');
const networkCtx = networkCanvas.getContext('2d');
const [startBtn, resetBtn, learningRateSlider, learningRateValue] = [
  'startBtn',
  'resetBtn',
  'learningRateSlider',
  'learningRateValue',
].map((id) => document.getElementById(id));

// Layer sizes, input first: one input neuron, two hidden layers of ten
// neurons each, and one output neuron.
const layers = [1, 10, 10, 1];

// Trainable parameters, (re)built by initializeNetwork(): entry l of each
// array holds the weights/biases that feed layer l + 1.
let weights = [];
let biases = [];
// Per-layer neuron outputs from the most recent forward pass.
let activations = [];

// Training state. The initial learning rate matches the slider's default value.
let learningRate = 0.01;
let isTraining = false;
let epoch = 0;
const maxEpochs = 10000;
/**
 * (Re)build the network with fresh random parameters.
 *
 * Afterwards, for every non-input layer l:
 *   weights[l - 1][j][k] - weight from neuron k of layer l - 1 to neuron j of layer l
 *   biases[l - 1][j]     - bias of neuron j of layer l
 *   activations[l]       - layers[l] zeros
 * and activations[0] holds layers[0] zeros. Every parameter is drawn uniformly
 * from [-0.5, 0.5); for each neuron, its incoming weights are drawn before its
 * bias.
 */
function initializeNetwork() {
  const randomParam = () => Math.random() - 0.5;
  const zeros = (count) => new Array(count).fill(0);

  weights = [];
  biases = [];
  activations = [zeros(layers[0])];

  for (let layer = 1; layer < layers.length; layer++) {
    const fanIn = layers[layer - 1];
    // Object literal properties evaluate in order: weights, then bias.
    const neurons = Array.from({ length: layers[layer] }, () => ({
      incoming: Array.from({ length: fanIn }, randomParam),
      bias: randomParam(),
    }));
    weights.push(neurons.map((neuron) => neuron.incoming));
    biases.push(neurons.map((neuron) => neuron.bias));
    activations.push(zeros(layers[layer]));
  }
}
/**
 * Hyperbolic-tangent activation applied by every non-input layer.
 * @param {number} x - A neuron's pre-activation sum.
 * @returns {number} Math.tanh(x), which lies in [-1, 1].
 */
function tanh(x) {
  const squashed = Math.tanh(x);
  return squashed;
}
/**
 * Propagate a scalar input through the network.
 *
 * Every neuron's output is written into the global `activations` array, which
 * backpropagate() and drawNetwork() read afterwards.
 *
 * @param {number} input - Value placed on the single input neuron.
 * @returns {number} Output of neuron 0 in the last layer.
 */
function forwardPass(input) {
  activations[0] = [input];
  for (let layer = 1; layer < layers.length; layer++) {
    const previous = activations[layer - 1];
    const current = activations[layer];
    const layerWeights = weights[layer - 1];
    const layerBiases = biases[layer - 1];
    const fanIn = layers[layer - 1];
    for (let n = 0; n < layers[layer]; n++) {
      const row = layerWeights[n];
      // Start from the bias, then add weighted inputs in index order.
      let z = layerBiases[n];
      for (let k = 0; k < fanIn; k++) {
        z += row[k] * previous[k];
      }
      current[n] = tanh(z);
    }
  }
  const outputLayer = activations[activations.length - 1];
  return outputLayer[0];
}
/**
 * One step of stochastic gradient descent on a single (input, target) pair.
 *
 * The deltas are the gradient of ½·(output − target)² with respect to each
 * neuron's pre-activation. They use tanh'(z) = 1 − tanh(z)², where
 * tanh(z) is the value forwardPass() stored in `activations`.
 *
 * All deltas are computed from the current weights BEFORE any parameter is
 * changed. Updating a layer's incoming weights while walking backwards, and
 * then reading those same weights to propagate the error to the layer below,
 * would mix pre- and post-update values in the gradient.
 *
 * @param {number} input - Network input.
 * @param {number} target - Desired network output.
 * @returns {number} Squared error (output − target)² measured before the update.
 */
function backpropagate(input, target) {
  const output = forwardPass(input);
  const outputError = output - target;
  const lastLayer = layers.length - 1;

  // deltas[i][j]: error term of neuron j in layer i (i >= 1).
  const deltas = new Array(layers.length);
  for (let i = lastLayer; i > 0; i--) {
    const layerDeltas = [];
    for (let j = 0; j < layers[i]; j++) {
      let upstream;
      if (i === lastLayer) {
        upstream = outputError;
      } else {
        // Error flowing back from layer i + 1 through the (unchanged) weights.
        upstream = 0;
        for (let k = 0; k < layers[i + 1]; k++) {
          upstream += deltas[i + 1][k] * weights[i][k][j];
        }
      }
      layerDeltas.push(upstream * (1 - activations[i][j] ** 2));
    }
    deltas[i] = layerDeltas;
  }

  // Apply the parameter updates only once every delta is known.
  for (let i = 1; i <= lastLayer; i++) {
    for (let j = 0; j < layers[i]; j++) {
      const delta = deltas[i][j];
      for (let k = 0; k < layers[i - 1]; k++) {
        weights[i - 1][j][k] -= learningRate * delta * activations[i - 1][k];
      }
      biases[i - 1][j] -= learningRate * delta;
    }
  }

  return outputError ** 2;
}
/**
 * Draw one sample x uniformly from [0, 2π), label it with sin(x), and run a
 * single backpropagation step on it.
 * @returns {number} The squared error reported by backpropagate().
 */
function trainingStep() {
  const x = Math.random() * 2 * Math.PI;
  const label = Math.sin(x);
  const squaredError = backpropagate(x, label);
  return squaredError;
}
/**
 * Render the network on networkCanvas.
 *
 * Layout: one evenly spaced column per layer and one circle per neuron. Each
 * circle is labelled with its current activation and, for non-input layers,
 * its bias. One line is drawn per weight, coloured from red (weight <= -1) to
 * blue (weight >= 1), with thickness proportional to |weight|. While
 * `isTraining` is true, each connection has a 10% chance per redraw of a
 * jagged yellow "zap" overlay.
 *
 * Connections are drawn in a first pass so that neurons and their labels sit
 * on top. The neuron outline style is set explicitly; otherwise it would
 * inherit whatever colour and width the last connection left on the context.
 */
function drawNetwork() {
  const radius = 15;
  const { width, height } = networkCanvas;
  const layerWidth = width / (layers.length + 1);
  const neuronX = (layer) => (layer + 1) * layerWidth;
  const neuronY = (layer, neuron) => (neuron + 1) * (height / (layers[layer] + 1));

  // Straight line from the right edge of one neuron to the left edge of the next.
  const drawConnection = (x, y, nextX, nextY, weight) => {
    // Map [-1, 1] onto [0, 1] for the colour ramp; trained weights can leave
    // that range, so clamp.
    const t = Math.min(1, Math.max(0, (weight + 1) / 2));
    networkCtx.beginPath();
    networkCtx.moveTo(x + radius, y);
    networkCtx.lineTo(nextX - radius, nextY);
    networkCtx.strokeStyle = `rgb(${255 * (1 - t)}, 0, ${255 * t})`;
    // Canvas ignores lineWidth = 0, which would silently reuse the previous
    // width, so keep a hairline minimum.
    networkCtx.lineWidth = Math.max(Math.abs(weight) * 2, 0.1);
    networkCtx.stroke();
  };

  // Purely decorative jagged yellow overlay along a connection.
  const drawZap = (x, y, nextX, nextY) => {
    networkCtx.beginPath();
    networkCtx.moveTo(x + radius, y);
    for (let t = 0.1; t < 1; t += 0.1) {
      const zapX = x + radius + (nextX - x - 2 * radius) * t;
      const zapY = y + (nextY - y) * t + (Math.random() - 0.5) * 10;
      networkCtx.lineTo(zapX, zapY);
    }
    networkCtx.lineTo(nextX - radius, nextY);
    networkCtx.strokeStyle = 'yellow';
    networkCtx.lineWidth = 2;
    networkCtx.stroke();
  };

  networkCtx.clearRect(0, 0, width, height);

  // Pass 1: every connection from layer i to layer i + 1.
  for (let i = 0; i < layers.length - 1; i++) {
    const x = neuronX(i);
    const nextX = neuronX(i + 1);
    for (let j = 0; j < layers[i]; j++) {
      const y = neuronY(i, j);
      for (let k = 0; k < layers[i + 1]; k++) {
        const nextY = neuronY(i + 1, k);
        // weights[i][k][j] connects neuron j of layer i to neuron k of layer i + 1.
        drawConnection(x, y, nextX, nextY, weights[i][k][j]);
        if (isTraining && Math.random() < 0.1) {
          drawZap(x, y, nextX, nextY);
        }
      }
    }
  }

  // Pass 2: neurons and their labels, on top of the connections.
  networkCtx.strokeStyle = 'black';
  networkCtx.lineWidth = 1;
  networkCtx.font = '10px Arial';
  networkCtx.textAlign = 'center';
  for (let i = 0; i < layers.length; i++) {
    const x = neuronX(i);
    for (let j = 0; j < layers[i]; j++) {
      const y = neuronY(i, j);
      networkCtx.beginPath();
      networkCtx.arc(x, y, radius, 0, 2 * Math.PI);
      networkCtx.fillStyle = 'lightblue';
      networkCtx.fill();
      networkCtx.stroke();

      networkCtx.fillStyle = 'black';
      networkCtx.fillText(activations[i][j].toFixed(2), x, y);
      // The input layer has no bias.
      if (i > 0) {
        networkCtx.fillText(`b: ${biases[i - 1][j].toFixed(2)}`, x, y + 25);
      }
    }
  }
}
// Function plot: sin(x) (target) versus the network's output, sampled over
// [0, 2π]. Both series are drawn as lines with no point markers.
// Animation is disabled, as on the error chart. The data is replaced every
// 10 training steps, so Chart.js's default update animation would keep
// restarting and the drawn curve would trail the network instead of showing it.
const functionPlot = new Chart(document.getElementById('functionPlot').getContext('2d'), {
  type: 'scatter',
  data: {
    datasets: [
      {
        label: 'Target (sin)',
        data: [],
        borderColor: 'rgb(75, 192, 192)',
        showLine: true,
        pointRadius: 0,
      },
      {
        label: 'Network Output',
        data: [],
        borderColor: 'rgb(255, 99, 132)',
        showLine: true,
        pointRadius: 0,
      },
    ],
  },
  options: {
    responsive: true,
    animation: {
      duration: 0,
    },
    scales: {
      x: {
        type: 'linear',
        position: 'bottom',
        min: 0,
        max: 2 * Math.PI,
      },
      // The tanh output layer keeps predictions within [-1, 1].
      y: {
        min: -1,
        max: 1,
      },
    },
  },
});
/**
 * Resample sin(x) and the network output at 101 evenly spaced points on
 * [0, 2π], store both series in functionPlot and redraw it.
 *
 * Each sample runs forwardPass(), so afterwards `activations` holds the
 * values for the last sample, x = 2π.
 */
function updateFunctionPlot() {
  const points = 100;
  const xs = Array.from({ length: points + 1 }, (_, i) => (i / points) * 2 * Math.PI);
  const targetSeries = xs.map((x) => ({ x, y: Math.sin(x) }));
  const networkSeries = xs.map((x) => ({ x, y: forwardPass(x) }));

  const [targetDataset, networkDataset] = functionPlot.data.datasets;
  targetDataset.data = targetSeries;
  networkDataset.data = networkSeries;
  functionPlot.update();
}
// Error chart: rolling line plot of the training error, one point per logged
// step. Animation is off so each update is drawn immediately.
const errorChart = new Chart(
  document.getElementById('errorChart').getContext('2d'),
  {
    type: 'line',
    data: {
      labels: [],
      datasets: [
        { label: 'Mean Squared Error', data: [], borderColor: 'rgb(75, 192, 192)', tension: 0.1 },
      ],
    },
    options: {
      responsive: true,
      animation: { duration: 0 },
      scales: { y: { beginAtZero: true } },
    },
  },
);
/**
 * Append one point (x = current `epoch`, y = error) to the error chart, keep
 * only the most recent 100 points, and redraw.
 * @param {number} error - Value to plot for the current epoch.
 */
function updateErrorChart(error) {
  const { data } = errorChart;
  const series = data.datasets[0].data;
  data.labels.push(epoch);
  series.push(error);
  // Only one point is added per call, so one shift restores the window size.
  if (data.labels.length > 100) {
    data.labels.shift();
    series.shift();
  }
  errorChart.update();
}
// Training-loop state owned by train(): the pending animation-frame handle
// (null when none is scheduled), and the running error sum/count since the
// last chart refresh.
let trainFrameId = null;
let errorSum = 0;
let errorCount = 0;

/**
 * One iteration of the training loop. It reschedules itself with
 * requestAnimationFrame while `isTraining` is true and `epoch < maxEpochs`.
 *
 * Every 10th epoch it redraws the network and the function plot and adds a
 * point to the error chart. That point is the mean squared error over the
 * steps since the previous refresh, which is what the chart's
 * "Mean Squared Error" label says it shows.
 *
 * If train() is called while a frame is already pending (e.g. Pause then
 * Resume within a single frame), the pending frame is cancelled first, so
 * only one loop ever runs. A user pause keeps the "Resume Training" label;
 * the button returns to "Start Training" only when the run has finished.
 */
function train() {
  if (trainFrameId !== null) {
    cancelAnimationFrame(trainFrameId);
    trainFrameId = null;
  }

  if (!isTraining || epoch >= maxEpochs) {
    isTraining = false;
    if (epoch >= maxEpochs) {
      startBtn.textContent = 'Start Training';
    }
    return;
  }

  // Epoch 0 begins a fresh run (first start or after Reset): discard any
  // partial sum left over from a previous run.
  if (epoch === 0) {
    errorSum = 0;
    errorCount = 0;
  }
  errorSum += trainingStep();
  errorCount++;

  if (epoch % 10 === 0) {
    drawNetwork();
    updateFunctionPlot();
    updateErrorChart(errorSum / errorCount);
    errorSum = 0;
    errorCount = 0;
  }
  epoch++;
  trainFrameId = requestAnimationFrame(train);
}
/**
 * Re-randomise the network, rewind the epoch counter to 0, empty the error
 * chart, and redraw the diagram and the function plot. `isTraining` is not
 * touched, so a running training loop continues from epoch 0.
 */
function resetNetwork() {
  // Fresh random parameters; counting restarts from zero.
  initializeNetwork();
  epoch = 0;

  // Empty the error chart by swapping in new arrays.
  const chartData = errorChart.data;
  chartData.labels = [];
  chartData.datasets[0].data = [];
  errorChart.update();

  // Repaint with the new parameters.
  drawNetwork();
  updateFunctionPlot();
}
// Start/Pause toggle: flip the training flag, relabel the button, and kick
// off the loop when training becomes active.
startBtn.addEventListener('click', () => {
  isTraining = !isTraining;
  if (!isTraining) {
    startBtn.textContent = 'Resume Training';
    return;
  }
  startBtn.textContent = 'Pause Training';
  train();
});

resetBtn.addEventListener('click', resetNetwork);

// Keep learningRate and its on-page readout in sync with the slider.
learningRateSlider.addEventListener('input', ({ target }) => {
  const rate = Number.parseFloat(target.value);
  learningRate = rate;
  learningRateValue.textContent = rate.toFixed(3);
});

// Initial render: a random network, its diagram, and its untrained output curve.
initializeNetwork();
drawNetwork();
updateFunctionPlot();
</script>
</body>
</html>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment