Created
February 19, 2025 21:28
-
-
Save angrypie/88dddc84435d46c0f52e7c5470fa1b4f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { GoogleGenerativeAI } from "@google/generative-ai"; | |
import { Mistral } from "@mistralai/mistralai"; | |
import * as diff from "diff"; | |
// const Groq = require('groq-sdk'); | |
// const Groq = require('groq-sdk'); | |
import { Groq } from "groq-sdk"; | |
/**
 * Request a code prediction for `prompt` from the selected LLM provider.
 *
 * The first line of `prompt` must contain a file name with a known extension;
 * it is turned into a "Language: ..." hint via `langFromFileName`. The elapsed
 * wall-clock time of the call is logged to the console.
 *
 * @param provider - Which backend to query.
 * @param prompt - Edit history followed by the code excerpt with region markers.
 * @returns The raw text produced by the model.
 * @throws Error when the provider's API key environment variable is missing,
 *   when the first prompt line has no recognisable file extension, or when
 *   Mistral returns no textual content.
 */
async function getPrediction(provider: "mistral" | "gemini" | 'groq', prompt: string): Promise<string> {
  const startTime = performance.now();

  // Fail fast with the same message style for every provider instead of
  // sending an empty key and waiting for the API to reject it.
  const requireEnv = (name: string): string => {
    const value = process.env[name] ?? "";
    if (!value) {
      throw new Error(`No ${name}`);
    }
    return value;
  };

  // Slice out the first line instead of splitting the whole prompt; a prompt
  // without any newline is treated as a single line.
  const newlineIndex = prompt.indexOf("\n");
  const firstLine = newlineIndex === -1 ? prompt : prompt.substring(0, newlineIndex);
  const language = langFromFileName(firstLine);

  // Fallback value; only returned if `provider` matches no case below.
  let prediction = '<NOT IMPLEMENTED>';
  const systemMessage = `Based on diff predict next changes or fix code. Respond only with code, no explanation, no formatting.`;
  const messages: { content: string; role: "system" | "user" }[] = [
    { content: systemMessage, role: "system" },
    { content: "do not edit code before <|editable_region_start|>", role: "user" },
    { content: "do not edit code past <|editable_region_end|>", role: "user" },
    { content: "fix code only inside <|editable_region_end|> and <|editable_region_start|>", role: "user" },
    { content: `Language: ${language}`, role: "user" },
    { content: prompt, role: "user" },
  ];

  switch (provider) {
    case "mistral": {
      const mistral = new Mistral({ apiKey: requireEnv("MISTRAL_API_KEY") });
      const result = await mistral.chat.complete({
        model: "codestral-latest", // 22b model
        // model: "codestral-mamba-latest", // 8b model
        // model: "ministral-3b-latest",
        // model: "ministral-8b-latest",
        stream: false,
        maxTokens: 500,
        temperature: 0,
        messages: messages,
        // The prompt itself is supplied as the predicted output.
        prediction: {
          type: 'content',
          content: prompt,
        },
      });
      console.log(">> Usage", result.usage);
      // Optional-chain the element too: an empty `choices` array must end in
      // the explicit "No text" error, not a TypeError.
      const text = result.choices?.[0]?.message.content;
      if (typeof text !== "string") {
        throw new Error("No text");
      }
      prediction = text;
      break;
    }
    case "gemini": {
      const genAI = new GoogleGenerativeAI(requireEnv("GEMINI_API_KEY"));
      const model = genAI.getGenerativeModel({
        model: "gemini-2.0-flash",
        generationConfig: {
          maxOutputTokens: 300,
          temperature: 0,
        },
      });
      // const model = genAI.getGenerativeModel({ model: "gemini-2.0-flash-lite-preview-02-05" });
      // NOTE(review): only the system message and the prompt are sent here;
      // the editable-region and language hints in `messages` are not. Confirm
      // whether that is intentional.
      const result = await model.generateContent([systemMessage, prompt]);
      prediction = result.response.text();
      break;
    }
    case "groq": {
      const groq = new Groq({ apiKey: requireEnv("GROQ_API_KEY") });
      const chatCompletion = await groq.chat.completions.create({
        // Reuse the shared message list so Groq receives the same
        // instructions (including the language hint) as Mistral.
        messages: messages,
        model: "qwen-2.5-coder-32b",
        // model: "deepseek-r1-distill-qwen-32b",
        // model: "deepseek-r1-distill-llama-70b",
        // model: "mixtral-8x7b-32768",
        // model: "llama-3.1-8b-instant",
        temperature: 0,
        max_completion_tokens: 4096 / 4,
        // top_p: 0.95,
        // stream: true,
        stop: null,
      });
      // for await (const chunk of chatCompletion) {
      //   process.stdout.write(chunk.choices[0]?.delta?.content || '');
      // }
      prediction = chatCompletion.choices[0]?.message.content || "EMPTY";
      break;
    }
  }

  // Keep the fractional part: rounding first made toFixed(2) always print ".00".
  const executionTime = performance.now() - startTime;
  console.log(`Exec time: ${executionTime.toFixed(2)}ms`);
  return prediction;
}
// Marker tokens embedded in prompts and model output.
// Position of the user's cursor inside the code excerpt.
const user_cusor_is_here = '<|user_cursor_is_here|>';
// Opening delimiter of the editable region.
const editable_region_start = '<|editable_region_start|>';
// Closing delimiter of the editable region.
const editable_region_end = '<|editable_region_end|>';
/** Wrap `text` in the ANSI escape codes for green foreground, then reset. */
function greenText(text: string) {
  const green = "\x1b[32m";
  const reset = "\x1b[0m";
  return green + text + reset;
}
/** Wrap `text` in the ANSI escape codes for red foreground, then reset. */
function redText(text: string) {
  const red = "\x1b[31m";
  const reset = "\x1b[0m";
  return red + text + reset;
}
/** Wrap `text` in the ANSI escape codes for strikethrough, then reset. */
function strikeThrough(text: string) {
  const strike = "\x1b[9m";
  const reset = "\x1b[0m";
  return strike + text + reset;
}
/**
 * Print a word-level diff between two texts to the console.
 *
 * Unchanged parts are printed as-is, additions in green and removals in red
 * strikethrough, followed by a footer line marking the end of the diff.
 *
 * @param prompt - The original text.
 * @param prompt2 - The text to compare against the original.
 */
function printDiff(prompt: string, prompt2: string) {
  const rendered = diff
    .diffWordsWithSpace(prompt, prompt2)
    .map(({ added, removed, value }) => {
      // A jsdiff change is an addition, a removal, or unchanged; it is never
      // flagged as both added and removed, so no combined case is needed.
      if (added) {
        return greenText(value);
      }
      if (removed) {
        return redText(strikeThrough(value));
      }
      return value;
    });
  console.log(rendered.join(""));
  console.log("End of diff >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>");
}
//makes big changes with "fix code" | |
// const systemMessage = `fix typescript code. do not explain final anwer. edit code only inside <|editable_region_end|> and <|editable_region_start|>.` | |
// makes small changes with "predict what user wants to change..."
// const prompt = `type ProviderLLM = { | |
// getMdels: () => Promise<Model[]>; | |
// generateeText: (model: Model, prompt: string) => Promise<string>; | |
// <|user_cursor_is_here|> | |
// }` | |
// Assemble the sample prompt: the recent edit history first, then the code
// excerpt that carries the editable-region and cursor markers.
const data = getData();
const prompt = ["Last edits:" + data.events, "Code in javascript:", data.output].join("\n");
/**
 * Repeatedly ask the model for a prediction, feeding each prediction back in
 * as the next prompt, and print a diff of the editable regions for each round.
 *
 * @param prompt - Initial prompt text.
 * @param rounds - Number of prediction rounds to run.
 * @param provider - Backend passed to `getPrediction`; defaults to "mistral".
 * @returns The text after the final round (the initial prompt if `rounds` is 0).
 */
async function improvePrompt(
  prompt: string,
  rounds: number,
  provider: "mistral" | "gemini" | "groq" = "mistral",
): Promise<string> {
  let current = prompt;
  for (let round = 0; round < rounds; round++) {
    const prediction = await getPrediction(provider, current);
    printDiff(
      format(current).editableRegion().get(),
      format(prediction).editableRegion().get(),
    );
    console.log("new gen=======================", round);
    // The prediction becomes the next prompt unchanged; no cursor marker is
    // re-inserted.
    // NOTE(review): getPrediction derives the language from the prompt's first
    // line, so rounds after the first presumably need the prediction to start
    // with a file-name line — confirm before using rounds > 1.
    current = prediction;
  }
  return current;
}
// Script entry point: run a single prediction round on the sample prompt.
const refinementRounds = 1;
await improvePrompt(prompt, refinementRounds);
/**
 * Build the hard-coded sample used to exercise the predictor.
 *
 * @returns An object with `output` (the code excerpt containing the
 *   editable-region and cursor markers) and `events` (a description of the
 *   user's most recent edit as a diff).
 */
function getData() {
  // Most recent user edit, rendered as a fenced diff.
  const events = `User edited "src/components/modals/ZetaReview.tsx":
\`\`\`diff
@@ -209,7 +209,7 @@
fileName: "src/utils/helpers.ts",
recency: "5 hours ago",
type: "file",
- latency: Math.floor(Math.random() * 1000),
+ latency: getRandomNumber(),
},
{
id: "6",
\`\`\`src/components/modals/ZetaReview.tsx`;

  // Code excerpt around the cursor, with the editable region delimited.
  const output = `
recency: "3 hours ago",
type: "diff",
latency: getRandomNumber(),
},
{
<|editable_region_start|>
id: "4",
fileName: "src/components/Dropdown.tsx",
recency: "4 hours ago",
type: "diff",
latency: getRandomNumber(),
},
{
id: "5",
fileName: "src/utils/helpers.ts",
recency: "5 hours ago",
type: "file",
latency: getRandomNumber(<|user_cursor_is_here|>),
},
{
id: "6",
fileName: "src/styles/globals.css",
recency: "6 hours ago",
type: "file",
latency: Math.floor(Math.random() * 1000),
},
{
id: "7",
fileName: "src/components/Modal.tsx",
recency: "7 hours ago",
type: "diff",
latency: Math.floor(Math.random() * 1000),
<|editable_region_end|>
},
{
id: "8",`;

  return { output, events };
}
/**
 * Map the file extension found at the end of `path` to a language name.
 *
 * Everything after the last "." is taken as the extension. Quotes and colons
 * are stripped, surrounding whitespace (including a trailing "\r") is trimmed,
 * and the result is lower-cased, so a line such as
 * `User edited "src/App.tsx":` resolves to "typescript react".
 *
 * @param path - A file path, or a line of text ending in one.
 * @returns The language name for the extension.
 * @throws Error if `path` contains no "." or the extension is not recognised.
 */
function langFromFileName(path: string ) {
  const dotIndex = path.lastIndexOf(".");
  if (dotIndex === -1) {
    throw new Error("No file extension found");
  }

  const extension = path
    .slice(dotIndex + 1)
    .replace(/"|'|:/g, "")
    .trim()
    .toLowerCase();

  // Kept inside the function: this function is called during the top-level
  // await earlier in the file, before any module-level constant declared at
  // this point would be initialised. A Map (not a plain object) avoids
  // matching inherited keys such as "constructor".
  const languageByExtension = new Map<string, string>([
    ["ts", "typescript"],
    ["js", "javascript"],
    ["jsx", "javascript react"],
    ["tsx", "typescript react"],
    ["py", "python"],
    ["lua", "lua"],
    ["rs", "rust"],
    ["json", "json"],
    ["html", "html"],
    ["css", "css"],
    ["toml", "toml"],
  ]);

  const language = languageByExtension.get(extension);
  if (language === undefined) {
    throw new Error("Unknown file extension:" + extension + "<");
  }
  return language;
}
/** Shape of the wrapper returned by `format`. */
type FormattedText = {
  user_cusor_is_here: string;
  editable_region_start: string;
  editable_region_end: string;
  editableRegion: () => FormattedText;
  get: () => string;
  getLines: () => string[];
};

/**
 * Wrap `text` with helpers for working with the editable-region markers.
 *
 * @param text - Prompt or prediction text, optionally containing the markers.
 * @returns An object exposing the marker constants plus:
 *   - `editableRegion()`: a new wrapper around the text between the start and
 *     end markers with cursor markers removed. A missing start marker means
 *     "from the beginning"; a missing end marker means "to the end".
 *   - `get()`: the wrapped text.
 *   - `getLines()`: the wrapped text split on "\n".
 */
function format(text: string): FormattedText {
  const editableRegion = (): FormattedText => {
    // The markers normally sit on their own lines. Skip the single newline
    // that follows the start marker, if there is one.
    const startTagIndex = text.indexOf(editable_region_start);
    let regionStart = 0;
    if (startTagIndex !== -1) {
      regionStart = startTagIndex + editable_region_start.length;
      if (text[regionStart] === "\n") {
        regionStart += 1;
      }
    }

    // Look for the end marker after the region start, and accept it even when
    // it is the very last thing in the text (no trailing newline). Drop the
    // newline right before it only when one is actually there.
    const endTagIndex = text.indexOf(editable_region_end, regionStart);
    let regionEnd = text.length;
    if (endTagIndex !== -1) {
      regionEnd = endTagIndex;
      if (regionEnd > regionStart && text[regionEnd - 1] === "\n") {
        regionEnd -= 1;
      }
    }

    // Remove every cursor marker, not just the first occurrence.
    const region = text
      .substring(regionStart, regionEnd)
      .split(user_cusor_is_here)
      .join("");
    return format(region);
  };

  return {
    user_cusor_is_here,
    editable_region_start,
    editable_region_end,
    editableRegion,
    get() {
      return text;
    },
    getLines() {
      return text.split("\n");
    },
  };
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment