Created
July 18, 2023 08:05
-
-
Save hmmhmmhm/fe42a1cb0c99b842bcdd83109e8c425a to your computer and use it in GitHub Desktop.
Fine Tuning Open AI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import fs from "fs"; | |
import * as tokenizer from "gpt-3-encoder"; | |
import { Configuration, OpenAIApi } from "openai"; | |
import { logger } from "./utils/logger.js"; | |
import jsonl from "jsonl"; | |
import inquirer from "inquirer"; | |
import chalk from "chalk"; | |
/**
 * Reads a JSON file and returns it as a map of string keys to string values.
 *
 * The file is read synchronously as UTF-8. The parsed value is validated
 * before it is returned, so a malformed asset fails here with a clear message
 * instead of surfacing later inside the tokenizer.
 *
 * @param assetPath Path of the JSON file to load.
 * @returns The parsed object; every value is guaranteed to be a string.
 * @throws SyntaxError when the file does not contain valid JSON.
 * @throws TypeError when the JSON is not a plain object of string values.
 */
export const loadProjectDataFromJSON = (
  assetPath: string
): Record<string, string> => {
  const parsed: unknown = JSON.parse(fs.readFileSync(assetPath, "utf8"));

  if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) {
    throw new TypeError(
      `Expected ${assetPath} to contain a JSON object of string values`
    );
  }

  for (const [key, value] of Object.entries(parsed)) {
    if (typeof value !== "string") {
      throw new TypeError(
        `Expected a string value for key "${key}" in ${assetPath}, received ${typeof value}`
      );
    }
  }

  // Safe: the loop above verified that every own value is a string.
  return parsed as Record<string, string>;
};
/**
 * Splits `code` into consecutive chunks of at most `limitToken` tokens each.
 *
 * The text is encoded once, the token list is cut into fixed-size slices, and
 * every slice is decoded back to text. An empty token list yields an empty
 * array. A fractional limit is rounded down.
 *
 * NOTE(review): chunks are cut purely on token boundaries. If the tokenizer is
 * byte-level, a multi-byte character straddling a boundary may not decode
 * cleanly in either chunk — confirm against gpt-3-encoder.
 *
 * @param code Text to split.
 * @param limitToken Maximum number of tokens per chunk; must be at least 1.
 * @returns The decoded chunks in their original order.
 * @throws RangeError when `limitToken` is NaN or smaller than 1. Such a limit
 *   previously produced a spurious empty first chunk or disabled splitting.
 */
export const tokenSplitter = (code: string, limitToken: number) => {
  // Written as a negated comparison so that NaN is rejected as well.
  if (!(limitToken >= 1)) {
    throw new RangeError(
      `limitToken must be a number >= 1, received ${limitToken}`
    );
  }

  const chunkSize = Math.floor(limitToken);
  const tokens = tokenizer.encode(code);
  const chunks: string[] = [];

  for (let start = 0; start < tokens.length; start += chunkSize) {
    chunks.push(tokenizer.decode(tokens.slice(start, start + chunkSize)));
  }

  return chunks;
};
/**
 * Converts a map of file paths to file contents into prompt/completion pairs.
 *
 * Each prompt is a small multi-line header holding a title, the file path and
 * a timestamp taken once per call. A file that fits within `tokenSplitCount`
 * tokens produces a single pair whose completion is the untouched content.
 * A longer file produces one pair per chunk, and the title then carries a
 * "(분할됨: i/n)" marker with the 1-based chunk index and the chunk count.
 *
 * The content is tokenized only once, by `tokenSplitter`: a result of at most
 * one chunk means the content fits within the limit.
 *
 * @param data Map of file path to file content.
 * @param tokenSplitCount Maximum number of tokens allowed per completion.
 * @returns The generated pairs, in the iteration order of `data`.
 */
export const convertToTrainingData = (
  data: Record<string, string>,
  tokenSplitCount: number
) => {
  const timestamp = new Date().toISOString();

  // Builds the header used as the prompt of every training example.
  const buildPrompt = (title: string, filePath: string): string =>
    [
      "{",
      `title: "${title}",`,
      `filePath: ${filePath},`,
      `lastModified: ${timestamp}`,
      "}",
    ].join("\n");

  const prompts: {
    prompt: string;
    completion: string;
  }[] = [];

  for (const [filePath, content] of Object.entries(data)) {
    const chunks = tokenSplitter(content, tokenSplitCount);

    if (chunks.length <= 1) {
      // Fits in one example: keep the original text rather than a decoded copy.
      prompts.push({
        prompt: buildPrompt("friday-gpt 프로젝트 파일", filePath),
        completion: content,
      });
      continue;
    }

    chunks.forEach((chunk, index) => {
      prompts.push({
        prompt: buildPrompt(
          `friday-gpt 프로젝트 파일 (분할됨: ${index + 1}/${chunks.length})`,
          filePath
        ),
        completion: chunk,
      });
    });
  }

  return prompts;
};
/**
 * Uploads the file at `assetPath` to OpenAI with the "fine-tune" purpose.
 *
 * The API key is read from the `OPENAI_API_KEY` environment variable.
 *
 * @param assetPath Path of the training file to upload.
 * @returns The id of the uploaded file as reported by the API.
 */
export const uploadTrainingData = async (assetPath: string) => {
  const client = new OpenAIApi(
    new Configuration({ apiKey: process.env.OPENAI_API_KEY })
  );

  // The read stream is cast to `any` as described in:
  // * https://github.com/openai/openai-node/issues/25#issuecomment-1291536117
  const trainingStream = fs.createReadStream(assetPath) as any;

  const { data: uploaded } = await client.createFile(
    trainingStream,
    "fine-tune"
  );
  return uploaded.id;
};
/**
 * Deletes a previously uploaded file from OpenAI.
 *
 * The API key is read from the `OPENAI_API_KEY` environment variable.
 *
 * @param fileId Id of the uploaded file to delete.
 * @returns The `deleted` flag from the API response.
 */
export const deleteTrainingData = async (fileId: string) => {
  const client = new OpenAIApi(
    new Configuration({ apiKey: process.env.OPENAI_API_KEY })
  );
  const { data: removal } = await client.deleteFile(fileId);
  return removal.deleted;
};
/**
 * Starts a fine-tune job on the "davinci" base model with the suffix
 * "friday-gpt".
 *
 * The API key is read from the `OPENAI_API_KEY` environment variable.
 *
 * @param traningDataId Id of the uploaded training file.
 * @param epochs Number of epochs, sent to the API as `n_epochs`.
 * @returns The job's `events` from the response and its id as `fineTuneId`.
 */
export const createFineTuneModel = async ({
  traningDataId,
  epochs,
}: {
  traningDataId: string;
  epochs: number;
}) => {
  const client = new OpenAIApi(
    new Configuration({ apiKey: process.env.OPENAI_API_KEY })
  );

  const { data: fineTune } = await client.createFineTune({
    training_file: traningDataId,
    model: "davinci",
    suffix: "friday-gpt",
    n_epochs: epochs,
  });

  return {
    events: fineTune.events,
    fineTuneId: fineTune.id,
  };
};
/**
 * Fetches the events recorded so far for a fine-tune job.
 *
 * The API key is read from the `OPENAI_API_KEY` environment variable.
 *
 * @param fineTuneId Id of the fine-tune job to query.
 * @returns The event list from the API response.
 */
export const getStatusOfFineTuning = async (fineTuneId: string) => {
  const client = new OpenAIApi(
    new Configuration({ apiKey: process.env.OPENAI_API_KEY })
  );
  const { data: eventList } = await client.listFineTuneEvents(fineTuneId);
  return eventList.data;
};
/**
 * Returns the current UTC time of day as `HH:mm:ss`.
 *
 * The value is the part of the ISO-8601 timestamp between the `T` separator
 * and the fractional seconds.
 */
export const now = () => {
  const iso = new Date().toISOString();
  return iso.slice(iso.indexOf("T") + 1, iso.indexOf("."));
};
/**
 * Runs the interactive fine-tuning workflow from start to finish.
 *
 * Steps:
 * 1. Load `./training/collected.json` and convert it to prompt/completion
 *    pairs limited to 1500 tokens per completion.
 * 2. Write the pairs to `./training/training.json` and convert that file to
 *    `./training/training.jsonl`.
 * 3. Ask for the number of epochs, print the estimated cost and ask for
 *    confirmation; anything other than "y" ends the run.
 * 4. Upload the JSONL file, start the fine-tune and poll its events every two
 *    seconds until an event reports an uploaded model, a failure or a
 *    cancellation.
 * 5. On success, write the model id to `./training/modelId.txt`.
 * 6. Delete the uploaded training file.
 *
 * Rejects if any file, stream or API operation fails.
 */
export const main = async () => {
  logger(`[${now()}] 파인튜닝할 프로젝트 데이터 로딩중...`);
  const data = loadProjectDataFromJSON("./training/collected.json");

  // * https://platform.openai.com/docs/models/gpt-3
  const prompts = convertToTrainingData(data, 1500);

  fs.writeFileSync("./training/training.json", JSON.stringify(prompts));

  // Convert the JSON array into JSONL. Every stage gets an error handler so a
  // failing stream rejects the promise instead of leaving it pending forever.
  await new Promise<void>((resolve, reject) => {
    const source = fs.createReadStream("./training/training.json");
    const toJsonl = jsonl();
    const sink = fs.createWriteStream("./training/training.jsonl");

    source.on("error", reject);
    toJsonl.on("error", reject);
    sink.on("error", reject);
    sink.on("finish", () => resolve());

    source.pipe(toJsonl).pipe(sink);
  });

  // * Ask how many epochs to train for (the default is 1, although 4 is
  // * the recommended number).
  const { repeatCount } = await inquirer.prompt({
    name: "repeatCount",
    message: chalk.magentaBright(
      `몇번 반복해서 학습할까요? (epochs 기본값: 1): `
    ),
  });
  // Anything that is not a positive integer falls back to the default of 1.
  const requestedEpochs = Number(repeatCount);
  const epochs =
    Number.isInteger(requestedEpochs) && requestedEpochs > 0
      ? requestedEpochs
      : 1;

  // * Cost estimate: 0.03 per 1000 tokens of the serialized pairs, per epoch.
  // * https://openai.com/pricing
  const token = tokenizer.encode(JSON.stringify(prompts));
  const trainingPrice = (token.length / 1000) * 0.03 * epochs;
  logger(
    `[${now()}] 프로젝트 전체 트레이닝 비용: $${trainingPrice} (${
      token.length
    } 토큰)`
  );

  // * Ask whether to proceed.
  const { answer } = await inquirer.prompt({
    name: "answer",
    message: chalk.magentaBright(`진행하시겠습니까? (y/n): `),
  });
  if (answer.toLowerCase() !== "y") {
    logger(`[${now()}] 종료합니다.`);
    return;
  }

  // * Upload the training data.
  const traningDataId = await uploadTrainingData("./training/training.jsonl");
  logger(`[${now()}] 트레이닝 데이터 업로드 완료: ${traningDataId}`);

  // * Create the fine-tuned model.
  const { fineTuneId } = await createFineTuneModel({ traningDataId, epochs });
  logger(`[${now()}] 파인튜닝 모델 학습이 시작되었습니다: ${fineTuneId}`);
  const startFineTuningTime = Date.now();

  // * Wait for the fine-tune to finish.
  // Events are de-duplicated on timestamp plus message, because distinct
  // events can share the same `created_at` value.
  const loggedEvents = new Set<string>();
  let outcome: "pending" | "succeeded" | "stopped" = "pending";

  while (outcome === "pending") {
    const events = await getStatusOfFineTuning(fineTuneId);

    for (const event of events) {
      const eventKey = `${event.created_at}:${event.message}`;
      if (!loggedEvents.has(eventKey)) {
        logger(`[${now()}] [Open A.I]: ${event.message}`);
        loggedEvents.add(eventKey);
      }

      if (
        event.message.includes("cancelled") ||
        event.message.includes("failed")
      ) {
        // Report the failure as a failure rather than as a completed run.
        logger(
          `[${now()}] 파인튜닝 모델 학습이 실패했거나 취소되었습니다: ${fineTuneId}`
        );
        outcome = "stopped";
        break;
      }

      if (event.message.startsWith("Uploaded model: ")) {
        const modelId = event.message.slice("Uploaded model: ".length);
        fs.writeFileSync("./training/modelId.txt", modelId);
        logger(`[${now()}] 파인튜닝 모델 학습이 완료되었습니다: ${modelId}`);
        outcome = "succeeded";
        break;
      }
    }

    // Only wait when another poll is needed; a finished job exits immediately.
    if (outcome === "pending") {
      await new Promise((resolve) => setTimeout(resolve, 2000));
    }
  }

  const fineTuningTime = (Date.now() - startFineTuningTime) / 1000;
  if (outcome === "succeeded") {
    logger(`[${now()}] ${fineTuningTime}초 만에 모델 학습이 완료되었습니다.`);
  } else {
    logger(`[${now()}] ${fineTuningTime}초 만에 모델 학습이 중단되었습니다.`);
  }

  // * Delete the uploaded training data.
  await deleteTrainingData(traningDataId);
  logger(`[${now()}] 트레이닝 데이터 삭제 완료: ${traningDataId}`);
};
// Start the workflow when this module is loaded. The rejection handler logs
// the failure and sets a non-zero exit code, so an error in any step is
// reported instead of surfacing as an unhandled promise rejection.
main().catch((error: unknown) => {
  const detail =
    error instanceof Error ? error.stack ?? error.message : String(error);
  logger(`[${now()}] ${detail}`);
  process.exitCode = 1;
});
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment