Last active
August 13, 2024 13:31
-
-
Save mikestaub/bcaee23838701e33e6d6d80b93f44f16 to your computer and use it in GitHub Desktop.
WebLLM demo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Triple LLM Inference Demo</title>
<style>
/* Page layout: one prompt box on top, model panels in a row below. */
.container {
display: flex;
flex-direction: column;
align-items: center;
}
.input-area {
width: 80%;
margin-bottom: 20px;
}
/* Wrapping row of model panels; each panel takes ~a third of the width. */
.models-container {
display: flex;
flex-wrap: wrap;
justify-content: space-around;
width: 100%;
}
.model {
width: 30%;
margin-bottom: 20px;
}
/* Toggled by the script to show/hide download status and stats lines. */
.hidden {
display: none;
}
textarea, select {
width: 100%;
box-sizing: border-box;
}
.button-container {
display: flex;
justify-content: space-between;
width: 80%;
margin-bottom: 20px;
}
/* Red "Stop" button; dimmed while no generation is running. */
#stop-generation {
background-color: #ff4d4d;
color: white;
padding: 10px 20px;
border: none;
cursor: pointer;
}
#stop-generation:disabled {
background-color: #ffb3b3;
cursor: not-allowed;
}
</style>
</head>
<body>
<div class="container">
<div class="input-area">
<h2>Enter your prompt:</h2>
<textarea id="user-input" rows="5" cols="50"></textarea>
<br>
<!-- Generate/Stop start disabled; the script enables them once models load. -->
<div class="button-container">
<button id="initialize-all">Initialize and Download All Models</button>
<button id="generate-all" disabled>Generate for All Models</button>
<button id="stop-generation" disabled>Stop Generation</button>
</div>
</div>
<div class="models-container" id="models-container">
<!-- Model containers will be dynamically inserted here -->
</div>
</div>
<script type="module">
// WebLLM runs LLM inference fully in-browser (WebGPU) — no server needed.
import * as webllm from "https://cdn.jsdelivr.net/npm/@mlc-ai/[email protected]/lib/index.min.js";
// Number of side-by-side model panels rendered on the page.
const NUMBER_OF_MODELS = 3;
// Every model id shipped in WebLLM's prebuilt app config.
const availableModels = webllm.prebuiltAppConfig.model_list.map(m => m.model_id);
// localStorage key under which per-panel model choices are persisted.
const STORAGE_KEY = 'llm-demo-models';
// Initial selection for each panel when nothing has been saved yet.
const DEFAULT_MODELS = [
"Llama-3.1-8B-Instruct-q4f16_1-MLC-1k",
"Phi-3-mini-4k-instruct-q4f16_1-MLC-1k",
"gemma-2-2b-it-q4f16_1-MLC-1k"
];
// One entry per initialized panel: { id, engine, messages, loaded }.
let models = [];
let isGenerating = false; // true while a batch generation is in flight
let shouldStop = false; // cooperative cancel flag polled by the stream loops
/**
 * Build the markup for one model panel. Every element id is suffixed
 * with the panel number so the rest of the script can address the
 * panel's select, status, output, and stats nodes directly.
 * @param {number} modelId - 1-based panel index.
 * @returns {string} HTML fragment for the panel.
 */
function createModelHTML(modelId) {
  const panelMarkup = `
<div class="model">
<h2>Model ${modelId}: <span id="model${modelId}-name"></span></h2>
<select id="model-selection${modelId}"></select>
<div id="download-status${modelId}" class="hidden"></div>
<div id="output${modelId}"></div>
<div id="chat-stats${modelId}" class="hidden"></div>
</div>
`;
  return panelMarkup;
}
// ---- One-time page wiring ----
// Build all panel markup in one string and assign innerHTML once:
// `innerHTML +=` inside a loop re-parses the whole container on every
// iteration (accidentally O(n^2)) and recreates all previously inserted
// nodes, which would also drop any listeners attached to them.
const modelsContainer = document.getElementById('models-container');
let panelsMarkup = '';
for (let i = 1; i <= NUMBER_OF_MODELS; i++) {
  panelsMarkup += createModelHTML(i);
}
modelsContainer.innerHTML = panelsMarkup;
// Populate each panel's <select> with every prebuilt model id.
for (let i = 1; i <= NUMBER_OF_MODELS; i++) {
  const selectElement = document.getElementById(`model-selection${i}`);
  availableModels.forEach(modelId => {
    const option = document.createElement("option");
    option.value = modelId;
    option.textContent = modelId;
    selectElement.appendChild(option);
  });
  // Pre-select the default; loadSavedModels() below overrides it with any
  // persisted choice.
  selectElement.value = DEFAULT_MODELS[i - 1] || availableModels[0];
  selectElement.addEventListener('change', (e) => saveModelSelection(i, e.target.value));
}
document.getElementById('initialize-all').addEventListener('click', initializeAllModels);
document.getElementById('generate-all').addEventListener('click', generateForAllModels);
document.getElementById('stop-generation').addEventListener('click', stopGeneration);
// Restore selections persisted from a previous visit.
loadSavedModels();
/**
 * Persist the model chosen for one panel to localStorage.
 * @param {number} modelId - 1-based panel index.
 * @param {string} selectedModel - model id picked in the panel's <select>.
 */
function saveModelSelection(modelId, selectedModel) {
  let savedModels;
  try {
    savedModels = JSON.parse(localStorage.getItem(STORAGE_KEY) || '{}');
  } catch {
    // Corrupted storage (hand-edited, or written by an older version) must
    // not permanently break saving — start over with a fresh object.
    savedModels = {};
  }
  savedModels[modelId] = selectedModel;
  localStorage.setItem(STORAGE_KEY, JSON.stringify(savedModels));
}
/**
 * Restore per-panel model selections saved by saveModelSelection().
 * Tolerates corrupt storage and saved ids that are no longer in the
 * prebuilt model list (e.g. after a WebLLM upgrade).
 */
function loadSavedModels() {
  let savedModels;
  try {
    savedModels = JSON.parse(localStorage.getItem(STORAGE_KEY) || '{}');
  } catch {
    savedModels = {};
  }
  for (let i = 1; i <= NUMBER_OF_MODELS; i++) {
    let model = savedModels[i] || DEFAULT_MODELS[i - 1] || availableModels[0];
    // A stale id would leave the <select> silently unchanged while the
    // heading displays it — fall back to a model that actually exists.
    if (!availableModels.includes(model)) {
      model = availableModels.includes(DEFAULT_MODELS[i - 1])
        ? DEFAULT_MODELS[i - 1]
        : availableModels[0];
    }
    document.getElementById(`model-selection${i}`).value = model;
    document.getElementById(`model${i}-name`).textContent = model;
  }
}
/**
 * Download/compile the model selected in each panel, all in parallel.
 * Re-entrant: clears previous engines and rebuilds `models` from scratch.
 * Per-engine errors are reported in the panel's status line; a failed
 * panel simply never reaches `loaded`, which keeps Generate disabled.
 */
async function initializeAllModels() {
  document.getElementById('initialize-all').disabled = true;
  // Re-initializing invalidates any previously loaded engines: block
  // generation until every panel has finished loading again (the original
  // left Generate enabled while `models` was being wiped and rebuilt).
  document.getElementById('generate-all').disabled = true;
  models = [];
  const initPromises = [];
  for (let i = 1; i <= NUMBER_OF_MODELS; i++) {
    const selectedModel = document.getElementById(`model-selection${i}`).value;
    const downloadStatus = document.getElementById(`download-status${i}`);
    downloadStatus.classList.remove("hidden");
    downloadStatus.textContent = "Initializing model...";
    const engine = new webllm.MLCEngine();
    engine.setInitProgressCallback((report) => {
      console.log(`Model ${i} initialize:`, report.progress);
      downloadStatus.textContent = report.text;
    });
    // Shared sampling settings for every engine.
    const config = {
      temperature: 0.7,
      top_p: 0.95,
    };
    initPromises.push(
      engine.reload(selectedModel, config)
        .then(() => {
          console.log(`Model ${i} initialized`);
          downloadStatus.textContent = "Model loaded successfully";
          models.push({
            id: i,
            engine: engine,
            messages: [],
            loaded: true
          });
        })
        .catch((error) => {
          console.error(`Failed to initialize model ${i}:`, error);
          downloadStatus.textContent = "Failed to initialize model";
        })
    );
  }
  await Promise.all(initPromises);
  document.getElementById('initialize-all').disabled = false;
  checkAllModelsInitialized();
}
/**
 * Enable the Generate button only once every panel has a loaded engine.
 */
function checkAllModelsInitialized() {
  const ready =
    models.length === NUMBER_OF_MODELS &&
    !models.some((model) => !model.loaded);
  document.getElementById('generate-all').disabled = !ready;
}
/**
 * Ask all running stream loops to bail out at their next chunk.
 * The Stop button is re-armed by the next generation run.
 */
function stopGeneration() {
  shouldStop = true;
  const stopButton = document.getElementById('stop-generation');
  stopButton.disabled = true;
}
/**
 * Stream a chat completion for one model panel, updating the output div
 * as tokens arrive and the stats line when the stream completes.
 * @param {{id:number, engine:object, messages:Array}} model - panel state.
 * @param {string} input - user prompt.
 * @param {number} [retryCount=0] - internal; current retry attempt.
 */
async function generateText(model, input, retryCount = 0) {
  const MAX_RETRIES = 3;
  // Append the user turn only on the first attempt — retries re-enter this
  // function and would otherwise duplicate the message in the history.
  if (retryCount === 0) {
    model.messages.push({ role: "user", content: input });
  }
  const outputDiv = document.getElementById(`output${model.id}`);
  outputDiv.textContent = "Generating...";
  const startTime = performance.now();
  try {
    let curMessage = "";
    let usage;
    const completion = await model.engine.chat.completions.create({
      stream: true,
      messages: model.messages,
      stream_options: { include_usage: true },
    });
    for await (const chunk of completion) {
      if (shouldStop) {
        console.log(`Stopping generation for model ${model.id}`);
        break;
      }
      const curDelta = chunk.choices[0]?.delta.content;
      if (curDelta) {
        curMessage += curDelta;
        outputDiv.textContent = curMessage;
      }
      // With stream_options.include_usage, the usage payload arrives on the
      // final chunk of the stream.
      if (chunk.usage) {
        usage = chunk.usage;
      }
    }
    const generationTime = (performance.now() - startTime) / 1000; // seconds
    if (!shouldStop) {
      const finalMessage = await model.engine.getMessage();
      outputDiv.textContent = finalMessage;
      model.messages.push({ role: "assistant", content: finalMessage });
      // Guard: if the stream ended without a usage chunk, skip the stats
      // line instead of throwing on `usage.total_tokens`.
      if (usage) {
        const totalTokens = usage.total_tokens;
        const tokensPerSecond = totalTokens / generationTime;
        const usageText =
          `prompt_tokens: ${usage.prompt_tokens}, ` +
          `completion_tokens: ${usage.completion_tokens}, ` +
          `total_tokens: ${totalTokens}, ` +
          `generation_time: ${generationTime.toFixed(2)} seconds, ` +
          `tokens_per_second: ${tokensPerSecond.toFixed(2)}`;
        document.getElementById(`chat-stats${model.id}`).classList.remove("hidden");
        document.getElementById(`chat-stats${model.id}`).textContent = usageText;
      }
    } else {
      // Drop the unanswered user turn so the history never contains two
      // consecutive user messages on the next run.
      _dropTrailingUserTurn(model);
      outputDiv.textContent += " (Generation stopped)";
    }
  } catch (error) {
    console.error(`Error generating text for model ${model.id}:`, error);
    // BindingError is a transient wasm-binding failure worth retrying;
    // don't retry once the user has asked to stop.
    if (error.name === "BindingError" && retryCount < MAX_RETRIES && !shouldStop) {
      console.log(`Retrying generation for model ${model.id} (attempt ${retryCount + 1})`);
      return generateText(model, input, retryCount + 1);
    }
    _dropTrailingUserTurn(model);
    outputDiv.textContent = "Error: Failed to generate text. Please try again or use a different model.";
  }
}
// Remove the most recent message iff it is an unanswered user turn, keeping
// the chat history well-formed (strictly alternating user/assistant).
function _dropTrailingUserTurn(model) {
  const last = model.messages[model.messages.length - 1];
  if (last && last.role === "user") {
    model.messages.pop();
  }
}
/**
 * Fan the prompt out to every loaded model concurrently and wait for all
 * streams to finish. No-op when the prompt is empty/whitespace.
 */
async function generateForAllModels() {
  const prompt = document.getElementById('user-input').value.trim();
  if (prompt.length === 0) {
    return;
  }
  isGenerating = true;
  shouldStop = false;
  const generateButton = document.getElementById('generate-all');
  const stopButton = document.getElementById('stop-generation');
  generateButton.disabled = true;
  stopButton.disabled = false;
  await Promise.all(models.map((model) => generateText(model, prompt)));
  isGenerating = false;
  generateButton.disabled = false;
  stopButton.disabled = true;
}
</script> | |
</body> | |
</html> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment