Skip to content

Instantly share code, notes, and snippets.

@mikestaub
Last active August 13, 2024 13:31
Show Gist options
  • Save mikestaub/bcaee23838701e33e6d6d80b93f44f16 to your computer and use it in GitHub Desktop.
Save mikestaub/bcaee23838701e33e6d6d80b93f44f16 to your computer and use it in GitHub Desktop.
webllm demo
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Triple LLM Inference Demo</title>
  <style>
    /* Page layout: prompt and controls on top, one column per model slot below. */
    .container {
      display: flex;
      flex-direction: column;
      align-items: center;
    }
    .input-area {
      width: 80%;
      margin-bottom: 20px;
    }
    .models-container {
      display: flex;
      flex-wrap: wrap;
      justify-content: space-around;
      width: 100%;
    }
    .model {
      width: 30%;
      margin-bottom: 20px;
    }
    /* Model replies are written with textContent; keep their line breaks and
       indentation instead of collapsing everything into one run of text. */
    .model [id^="output"] {
      white-space: pre-wrap;
      overflow-wrap: break-word;
    }
    /* Toggled from the script to show/hide per-model status and stats lines. */
    .hidden {
      display: none;
    }
    textarea, select {
      width: 100%;
      box-sizing: border-box;
    }
    .button-container {
      display: flex;
      justify-content: space-between;
      width: 80%;
      margin-bottom: 20px;
    }
    #stop-generation {
      background-color: #ff4d4d;
      color: white;
      padding: 10px 20px;
      border: none;
      cursor: pointer;
    }
    #stop-generation:disabled {
      background-color: #ffb3b3;
      cursor: not-allowed;
    }
  </style>
</head>
<body>
  <div class="container">
    <div class="input-area">
      <h2>Enter your prompt:</h2>
      <textarea id="user-input" rows="5" cols="50"></textarea>
      <br>
      <div class="button-container">
        <button id="initialize-all">Initialize and Download All Models</button>
        <button id="generate-all" disabled>Generate for All Models</button>
        <button id="stop-generation" disabled>Stop Generation</button>
      </div>
    </div>
    <div class="models-container" id="models-container">
      <!-- One panel per model slot is inserted here by the script below. -->
    </div>
  </div>
  <script type="module">
// WebLLM's documented ES-module entry point. It is unpinned, so it tracks the latest release.
// NOTE(review): pin a known-good version (e.g. "https://esm.run/@mlc-ai/web-llm@0.2") before relying on this demo.
import * as webllm from "https://esm.run/@mlc-ai/web-llm";

/** Number of side-by-side model slots rendered on the page. */
const NUMBER_OF_MODELS = 3;
/** IDs of every prebuilt model WebLLM knows about; used to fill each slot's <select>. */
const availableModels = webllm.prebuiltAppConfig.model_list.map((m) => m.model_id);
/** localStorage key for the per-slot selections, stored as { [slot]: modelId }. */
const STORAGE_KEY = 'llm-demo-models';
/** Initial selection per slot (index 0 = slot 1) when no selection is saved. */
const DEFAULT_MODELS = [
  "Llama-3.1-8B-Instruct-q4f16_1-MLC-1k",
  "Phi-3-mini-4k-instruct-q4f16_1-MLC-1k",
  "gemma-2-2b-it-q4f16_1-MLC-1k",
];
/** Loaded slots, each { id, engine, messages, loaded }; filled by initializeAllModels. */
let models = [];
/** True while generateForAllModels is running. */
let isGenerating = false;
/** Set by stopGeneration; generateText checks it between streamed chunks. */
let shouldStop = false;
/**
 * Returns the markup for one model slot panel.
 *
 * Every element id is suffixed with the slot number so the rest of the script
 * can address it: heading name span, model <select>, download status, output
 * and chat stats.
 *
 * @param {number} modelId - 1-based slot index (not a model identifier).
 * @returns {string} HTML for the panel, with a leading and trailing newline.
 */
function createModelHTML(modelId) {
  const slot = modelId;
  const lines = [
    '<div class="model">',
    `<h2>Model ${slot}: <span id="model${slot}-name"></span></h2>`,
    `<select id="model-selection${slot}"></select>`,
    `<div id="download-status${slot}" class="hidden"></div>`,
    `<div id="output${slot}"></div>`,
    `<div id="chat-stats${slot}" class="hidden"></div>`,
    '</div>',
  ];
  return ['', ...lines, ''].join('\n');
}
// Render one panel per model slot.
const modelsContainer = document.getElementById('models-container');
for (let i = 1; i <= NUMBER_OF_MODELS; i++) {
  // insertAdjacentHTML appends without re-serializing and re-parsing the panels
  // already in the container, which `innerHTML +=` would do on every pass.
  modelsContainer.insertAdjacentHTML('beforeend', createModelHTML(i));
}
// Fill each slot's <select> with every prebuilt model and wire up its change handler.
for (let i = 1; i <= NUMBER_OF_MODELS; i++) {
  const selectElement = document.getElementById(`model-selection${i}`);
  availableModels.forEach((modelId) => {
    const option = document.createElement("option");
    option.value = modelId;
    option.textContent = modelId;
    selectElement.appendChild(option);
  });
  // Default selection; loadSavedModels() below may replace it.
  selectElement.value = DEFAULT_MODELS[i - 1] || availableModels[0];
  selectElement.addEventListener('change', (e) => {
    saveModelSelection(i, e.target.value);
    // Keep the panel heading in step with the selection; otherwise it is only
    // set once, by loadSavedModels() at page load.
    document.getElementById(`model${i}-name`).textContent = e.target.value;
  });
}
document.getElementById('initialize-all').addEventListener('click', initializeAllModels);
document.getElementById('generate-all').addEventListener('click', generateForAllModels);
document.getElementById('stop-generation').addEventListener('click', stopGeneration);
// Restore saved model selections (function declarations are hoisted, so this call is safe here).
loadSavedModels();
/**
 * Persists one slot's model choice in localStorage under STORAGE_KEY.
 *
 * The stored value is a JSON object keyed by slot number. Unreadable or
 * non-object stored data is discarded and replaced rather than crashing the
 * <select> change handler. Persistence is best-effort: if storage rejects the
 * write (full or disabled), a warning is logged and the on-page selection
 * still stands.
 *
 * @param {number} modelId - 1-based slot index (not a model identifier).
 * @param {string} selectedModel - WebLLM model ID chosen for that slot.
 */
function saveModelSelection(modelId, selectedModel) {
  let savedModels = {};
  try {
    const parsed = JSON.parse(localStorage.getItem(STORAGE_KEY) || '{}');
    // JSON.parse("null") or an array would break the keyed assignment below.
    if (parsed !== null && typeof parsed === 'object' && !Array.isArray(parsed)) {
      savedModels = parsed;
    }
  } catch (error) {
    console.warn('Discarding unreadable saved model selections:', error);
  }
  savedModels[modelId] = selectedModel;
  try {
    localStorage.setItem(STORAGE_KEY, JSON.stringify(savedModels));
  } catch (error) {
    console.warn('Could not save model selection:', error);
  }
}
/**
 * Restores each slot's model selection from localStorage and mirrors it into
 * the slot heading.
 *
 * For every slot, the first candidate that is actually in `availableModels`
 * wins: the saved choice, then the slot's DEFAULT_MODELS entry, then the first
 * available model. Unreadable or non-object stored data is treated as
 * "nothing saved".
 */
function loadSavedModels() {
  let savedModels = {};
  try {
    const parsed = JSON.parse(localStorage.getItem(STORAGE_KEY) || '{}');
    if (parsed !== null && typeof parsed === 'object' && !Array.isArray(parsed)) {
      savedModels = parsed;
    }
  } catch (error) {
    console.warn('Ignoring unreadable saved model selections:', error);
  }
  // An ID missing from the <select>'s options would leave it with no selected
  // option (value ""), which initialization would then try to load.
  const isAvailable = (id) => typeof id === 'string' && availableModels.includes(id);
  for (let i = 1; i <= NUMBER_OF_MODELS; i++) {
    const savedModel = [savedModels[i], DEFAULT_MODELS[i - 1]].find(isAvailable) ?? availableModels[0];
    document.getElementById(`model-selection${i}`).value = savedModel;
    document.getElementById(`model${i}-name`).textContent = savedModel;
  }
}
/**
 * Loads the model selected in every slot into its own MLCEngine, in parallel,
 * then enables "Generate for All Models" only if every slot loaded.
 *
 * Engines from a previous initialization are unloaded first, so repeated
 * clicks do not keep the old models resident. A failure in one slot is
 * reported in that slot's status line and does not abort the other slots.
 * Ignored (with a warning) while a generation is running, because replacing
 * engines would break the in-flight streams.
 */
async function initializeAllModels() {
  if (isGenerating) {
    console.warn('Stop generation before re-initializing models.');
    return;
  }
  const initButton = document.getElementById('initialize-all');
  initButton.disabled = true;
  // `models` is about to be replaced; block generation until the new set is ready.
  document.getElementById('generate-all').disabled = true;
  try {
    const previousModels = models;
    models = [];
    await Promise.all(previousModels.map(({ id, engine }) =>
      engine.unload().catch((error) => console.warn(`Failed to unload model ${id}:`, error)),
    ));
    const initPromises = [];
    for (let i = 1; i <= NUMBER_OF_MODELS; i++) {
      initPromises.push(initializeModelSlot(i));
    }
    await Promise.all(initPromises);
  } finally {
    initButton.disabled = false;
    checkAllModelsInitialized();
  }
}

/**
 * Creates an engine for one slot, loads the model chosen in its <select>,
 * and appends the slot record to `models` on success.
 *
 * @param {number} slot - 1-based slot index.
 * @returns {Promise<void>} Never rejects; failures are shown in the slot's status line.
 */
async function initializeModelSlot(slot) {
  const selectedModel = document.getElementById(`model-selection${slot}`).value;
  const downloadStatus = document.getElementById(`download-status${slot}`);
  downloadStatus.classList.remove("hidden");
  downloadStatus.textContent = "Initializing model...";
  try {
    const engine = new webllm.MLCEngine();
    engine.setInitProgressCallback((report) => {
      console.log(`Model ${slot} initialize:`, report.progress);
      downloadStatus.textContent = report.text;
    });
    const config = {
      temperature: 0.7,
      top_p: 0.95,
    };
    await engine.reload(selectedModel, config);
    console.log(`Model ${slot} initialized`);
    downloadStatus.textContent = "Model loaded successfully";
    models.push({
      id: slot,
      engine,
      messages: [],
      loaded: true,
    });
  } catch (error) {
    console.error(`Failed to initialize model ${slot}:`, error);
    downloadStatus.textContent = "Failed to initialize model";
  }
}
/**
 * Enables "Generate for All Models" only when every slot has an entry in
 * `models` and each entry is marked loaded; disables it otherwise.
 */
function checkAllModelsInitialized() {
  const ready = models.length === NUMBER_OF_MODELS && !models.some((entry) => !entry.loaded);
  const generateButton = document.getElementById('generate-all');
  generateButton.disabled = ready === false;
}
/**
 * Requests that all running generations stop, and disables the Stop button.
 * The request is cooperative: generateText checks `shouldStop` between
 * streamed chunks.
 */
function stopGeneration() {
  shouldStop = true;
  const stopButton = document.getElementById('stop-generation');
  stopButton.disabled = true;
}
/**
 * Streams one chat turn from `model` into its output panel and records the
 * turn in `model.messages`.
 *
 * The user message and the assistant reply are appended to the history only
 * after a complete reply. A stopped or failed attempt therefore leaves the
 * history unchanged, and a retry does not send the user message twice.
 * Errors are shown in the panel rather than thrown. A "BindingError" is
 * retried up to MAX_RETRIES times unless the user has asked to stop.
 *
 * @param {{id: number, engine: object, messages: Array<{role: string, content: string}>}} model
 *   Slot record as built during initialization.
 * @param {string} input - The prompt for this turn.
 * @param {number} [retryCount=0] - Retries already made (used by the retry path).
 * @returns {Promise<void>}
 */
async function generateText(model, input, retryCount = 0) {
  const MAX_RETRIES = 3;
  const userMessage = { role: "user", content: input };
  const outputDiv = document.getElementById(`output${model.id}`);
  const statsDiv = document.getElementById(`chat-stats${model.id}`);
  outputDiv.textContent = "Generating...";
  // Stats from an earlier turn must not appear next to this turn's output.
  statsDiv.classList.add("hidden");
  const startTime = performance.now();
  try {
    let curMessage = "";
    let usage;
    const completion = await model.engine.chat.completions.create({
      stream: true,
      // Send history plus this turn without committing the turn yet.
      messages: [...model.messages, userMessage],
      stream_options: { include_usage: true },
    });
    for await (const chunk of completion) {
      if (shouldStop) {
        console.log(`Stopping generation for model ${model.id}`);
        break;
      }
      const curDelta = chunk.choices[0]?.delta.content;
      if (curDelta) {
        curMessage += curDelta;
        outputDiv.textContent = curMessage;
      }
      if (chunk.usage) {
        usage = chunk.usage;
      }
    }
    const generationTime = (performance.now() - startTime) / 1000; // seconds
    if (shouldStop) {
      outputDiv.textContent += " (Generation stopped)";
      return;
    }
    const finalMessage = await model.engine.getMessage();
    outputDiv.textContent = finalMessage;
    model.messages.push(userMessage, { role: "assistant", content: finalMessage });
    // The usage chunk is requested but not guaranteed; a missing one should not
    // turn an already-successful reply into an error.
    if (usage) {
      const totalTokens = usage.total_tokens;
      const tokensPerSecond = totalTokens / generationTime;
      statsDiv.textContent =
        `prompt_tokens: ${usage.prompt_tokens}, ` +
        `completion_tokens: ${usage.completion_tokens}, ` +
        `total_tokens: ${totalTokens}, ` +
        `generation_time: ${generationTime.toFixed(2)} seconds, ` +
        `tokens_per_second: ${tokensPerSecond.toFixed(2)}`;
      statsDiv.classList.remove("hidden");
    }
  } catch (error) {
    console.error(`Error generating text for model ${model.id}:`, error);
    if (shouldStop) {
      outputDiv.textContent += " (Generation stopped)";
      return;
    }
    if (error?.name === "BindingError" && retryCount < MAX_RETRIES) {
      console.log(`Retrying generation for model ${model.id} (attempt ${retryCount + 1})`);
      return generateText(model, input, retryCount + 1);
    }
    const attempts = retryCount + 1;
    outputDiv.textContent = `Error: Failed to generate text after ${attempts} attempt${attempts === 1 ? "" : "s"}. Please try again or use a different model.`;
  }
}
/**
 * Sends the prompt to every loaded model in parallel and manages button state
 * for the duration of the run. Whitespace-only prompts are ignored.
 *
 * "Initialize" is disabled while generating, because re-initializing would
 * replace engines that are still streaming. All models are awaited even if one
 * rejects unexpectedly, so the buttons are never re-enabled while others are
 * still running.
 */
async function generateForAllModels() {
  const input = document.getElementById('user-input').value.trim();
  if (input.length === 0) return;
  const initButton = document.getElementById('initialize-all');
  const generateButton = document.getElementById('generate-all');
  const stopButton = document.getElementById('stop-generation');
  isGenerating = true;
  shouldStop = false;
  generateButton.disabled = true;
  initButton.disabled = true;
  stopButton.disabled = false;
  try {
    const activeModels = models;
    const results = await Promise.allSettled(activeModels.map((model) => generateText(model, input)));
    results.forEach((result, index) => {
      if (result.status === 'rejected') {
        console.error(`Generation failed for model ${activeModels[index].id}:`, result.reason);
      }
    });
  } finally {
    isGenerating = false;
    generateButton.disabled = false;
    initButton.disabled = false;
    stopButton.disabled = true;
  }
}
</script>
</body>
</html>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment