Created
February 13, 2023 08:01
-
-
Save Hokid/06452b355a6b88a758a6c096df34ab5d to your computer and use it in GitHub Desktop.
simple-code-search-engine-using-open-ai-api
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const path = require('path'); | |
const ts = require('typescript'); | |
const csv = require('csv-stringify/sync'); | |
// Resolve the TypeScript project rooted at the current working directory.
const cwd = process.cwd();
const configPath = path.join(cwd, 'tsconfig.json');

// tsconfig.json is JSON-with-comments (the file generated by `tsc --init` is
// full of comments), which require()/JSON.parse reject with a SyntaxError.
// ts.readConfigFile understands that dialect and returns a diagnostic on failure.
const { config: configJSON, error: configReadError } = ts.readConfigFile(configPath, ts.sys.readFile);
if (configReadError) {
  throw new Error(
    `Cannot load ${configPath}: ${ts.flattenDiagnosticMessageText(configReadError.messageText, '\n')}`
  );
}

// Expand the raw config (extends, include/exclude/files, compilerOptions) into
// the concrete list of root files and parsed compiler options.
const config = ts.parseJsonConfigFileContent(configJSON, ts.sys, cwd, undefined, configPath);
// Config problems are reported but not fatal; they go to stderr so stdout stays pure CSV.
for (const diagnostic of config.errors) {
  console.error(`tsconfig: ${ts.flattenDiagnosticMessageText(diagnostic.messageText, '\n')}`);
}

const program = ts.createProgram(
  config.fileNames,
  config.options,
  ts.createCompilerHost(config.options)
);
// Single type checker used by the helpers below to resolve symbols and their doc comments.
const checker = program.getTypeChecker();
/** Collected index rows, in the order declarations are discovered. */
const rows = [];

/**
 * Append one index row.
 *
 * @param {string} fileName - path of the source file; stored relative to `cwd`.
 * @param {string} name - symbol name (members are `Owner:member`).
 * @param {string} code - source text (or a stub) for the declaration.
 * @param {string} [docs=''] - rendered documentation comment.
 * @returns {number} the new row count (the result of `Array#push`).
 */
const addRow = (fileName, name, code, docs = '') => {
  const relativeFile = path.relative(cwd, fileName);
  const row = { file_name: relativeFile, name, code, docs };
  return rows.push(row);
};
/**
 * Record a function declaration as an index row containing its full source text.
 * Nothing is recorded when the checker resolves no symbol for `node.name`.
 *
 * @param {string} fileName - file the declaration lives in.
 * @param {object} node - a function declaration node.
 */
function addFunction(fileName, node) {
  const symbol = checker.getSymbolAtLocation(node.name);
  if (!symbol) {
    return;
  }
  const functionName = symbol.getName();
  addRow(fileName, functionName, node.getText(), getDocs(symbol));
}
/**
 * Record a class declaration and then each of its members.
 * The class row stores only an empty `class Name {}` stub as its code; member
 * rows carry the member source text.
 *
 * @param {string} fileName - file the declaration lives in.
 * @param {object} node - a class declaration node.
 */
function addClass(fileName, node) {
  const symbol = checker.getSymbolAtLocation(node.name);
  if (!symbol) {
    return;
  }
  const className = symbol.getName();
  const classDocs = getDocs(symbol);
  addRow(fileName, className, `class ${className} {}`, classDocs);
  for (const member of node.members) {
    addClassMember(fileName, className, member);
  }
}
/**
 * Record one class member under the qualified name `ClassName:member`.
 * Members for which no symbol resolves from `node.name` are skipped.
 *
 * @param {string} fileName - file the class lives in.
 * @param {string} className - name of the owning class.
 * @param {object} node - a class member node.
 */
function addClassMember(fileName, className, node) {
  const symbol = checker.getSymbolAtLocation(node.name);
  if (!symbol) {
    return;
  }
  const qualifiedName = `${className}:${symbol.getName()}`;
  const memberDocs = getDocs(symbol);
  addRow(fileName, qualifiedName, node.getText(), memberDocs);
}
/**
 * Record an interface declaration and then each of its members.
 * The interface row stores only an empty `interface Name {}` stub as its code.
 *
 * @param {string} fileName - file the declaration lives in.
 * @param {object} node - an interface declaration node.
 */
function addInterface(fileName, node) {
  const symbol = checker.getSymbolAtLocation(node.name);
  if (!symbol) {
    return;
  }
  const interfaceName = symbol.getName();
  const interfaceDocs = getDocs(symbol);
  addRow(fileName, interfaceName, `interface ${interfaceName} {}`, interfaceDocs);
  for (const member of node.members) {
    addInterfaceMember(fileName, interfaceName, member);
  }
}
/**
 * Record one interface member under the qualified name `InterfaceName:member`.
 *
 * Only property signatures and method signatures are indexed. Other member
 * kinds (for example index, call or construct signatures) are ignored.
 *
 * Fix: the previous guard `!isPropertySignature(node) || !isMethodSignature(node)`
 * was true for every node, because no node is both kinds at once. It therefore
 * skipped all interface members. The guard now rejects only nodes that are
 * neither kind.
 *
 * @param {string} fileName - file the interface lives in.
 * @param {string} interfaceName - name of the owning interface.
 * @param {object} node - an interface member node.
 */
function addInterfaceMember(fileName, interfaceName, node) {
  const isIndexableMember = ts.isPropertySignature(node) || ts.isMethodSignature(node);
  if (!isIndexableMember) {
    return;
  }
  const symbol = checker.getSymbolAtLocation(node.name);
  if (!symbol) {
    return;
  }
  const name = `${interfaceName}:${symbol.getName()}`;
  const docs = getDocs(symbol);
  const code = node.getText();
  addRow(fileName, name, code, docs);
}
/**
 * Render a symbol's documentation comment as plain text.
 *
 * @param {object} symbol - a symbol obtained from `checker`.
 * @returns {string} the concatenated documentation display parts.
 */
function getDocs(symbol) {
  const displayParts = symbol.getDocumentationComment(checker);
  return ts.displayPartsToString(displayParts);
}
// Declaration kinds that produce index rows, tried in order; the first matching
// predicate wins.
const declarationHandlers = [
  [ts.isFunctionDeclaration, addFunction],
  [ts.isClassDeclaration, addClass],
  [ts.isInterfaceDeclaration, addInterface],
];

// Walk every root file of the project in pre-order. Each node is offered to the
// handlers before its children are visited, so declarations nested at any depth
// are recorded too.
for (const fileName of config.fileNames) {
  const walk = (node) => {
    const handler = declarationHandlers.find(([matches]) => matches(node));
    if (handler) {
      handler[1](fileName, node);
    }
    // ts.forEachChild stops at the first truthy callback result, so this
    // visitor deliberately returns nothing.
    ts.forEachChild(node, walk);
  };
  const sourceFile = program.getSourceFile(fileName);
  ts.forEachChild(sourceFile, walk);
}
// Build the single text field that is later embedded: the documentation (if any),
// followed by the code and the name.
for (const row of rows) {
  const docsPrefix = row.docs ? `Code documentation: ${row.docs}; ` : '';
  row.combined = `${docsPrefix}Code: ${row.code}; Name: ${row.name};`;
}
// Serialize all rows as CSV with a header line and print the result on stdout,
// where consumers of this script read it.
const csvOptions = { header: true };
const output = csv.stringify(rows, csvOptions);
console.log(output);
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from io import StringIO | |
from subprocess import PIPE, run | |
from pandas import read_csv | |
from openai.embeddings_utils import get_embedding as _get_embedding | |
from tenacity import wait_random_exponential, stop_after_attempt | |
# Retry policy for the embedding API call: random exponential waits bounded by
# wait_random_exponential(min=1, max=60), giving up after 10 attempts.
_retry_wait = wait_random_exponential(min=1, max=60)
_retry_stop = stop_after_attempt(10)
get_embedding = _get_embedding.retry_with(wait=_retry_wait, stop=_retry_stop)
if __name__ == '__main__':
    # Step 1: run the Node exporter script; its stdout is expected to be CSV with
    # a "combined" column. On failure, surface its stderr as the error message.
    exporter = run(['node', 'code-to-csv.js'], capture_output=True, text=True)
    if exporter.returncode != 0:
        raise RuntimeError(exporter.stderr)

    # Step 2: parse the exported CSV text into a DataFrame.
    table = read_csv(StringIO(exporter.stdout))

    # Step 3: embed each row's pre-built "combined" text.
    def embed(text):
        return get_embedding(text, engine='text-embedding-ada-002')

    table['embedding'] = table['combined'].apply(embed)

    # Step 4: write rows plus embeddings to disk (without the pandas index).
    table.to_csv("search_db.csv", index=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import numpy as np | |
from pandas import read_csv | |
from openai.embeddings_utils import cosine_similarity, get_embedding as _get_embedding | |
from tenacity import stop_after_attempt, wait_random_exponential | |
# Retry policy for embedding the search query: random exponential waits bounded
# by wait_random_exponential(min=1, max=60), giving up after 10 attempts.
_retry_wait = wait_random_exponential(min=1, max=60)
_retry_stop = stop_after_attempt(10)
get_embedding = _get_embedding.retry_with(wait=_retry_wait, stop=_retry_stop)
def search(db, query, top_n=3, max_code_lines=7):
    """Rank the rows of ``db`` by similarity to ``query`` and format the best hits.

    Args:
        db: DataFrame with at least ``file_name``, ``name``, ``code``, ``docs``
            and ``embedding`` columns. It is not modified; earlier versions
            added a ``similarities`` column and re-sorted it in place.
        query: free-text query, embedded with ``text-embedding-ada-002``.
        top_n: number of best-matching rows to include (default 3).
        max_code_lines: how many leading lines of each hit's code to show
            (default 7).

    Returns:
        A string with one section per hit, in descending similarity order.
        Each section contains the docs as a ``/** ... */`` block (only when
        ``docs`` is a string), the truncated code, a
        ``[score=...] file_name:name`` line (score rounded to 3 places), and a
        70-dash separator.
    """
    # 4. Embed the query.
    query_embedding = get_embedding(query, engine='text-embedding-ada-002')
    # 5. Score every row against the query.
    similarities = db.embedding.apply(lambda x: cosine_similarity(x, query_embedding))
    # 6. assign() returns a new frame, so the caller's DataFrame stays untouched.
    ranked = db.assign(similarities=similarities).sort_values('similarities', ascending=False)

    text = ""
    for row in ranked.head(top_n).itertuples(index=False):
        score = round(row.similarities, 3)
        # Empty docs cells are read back from CSV as NaN (a float), so only
        # real strings get a doc-comment block.
        if isinstance(row.docs, str):
            text += '/**\n * {docs}\n */\n'.format(docs='\n * '.join(row.docs.split('\n')))
        text += '{code}\n\n'.format(code='\n'.join(row.code.split('\n')[:max_code_lines]))
        text += '[score={score}] {file_name}:{name}\n'.format(score=score, file_name=row.file_name, name=row.name)
        text += '-' * 70 + '\n\n'
    return text
if __name__ == '__main__':
    # literal_eval accepts only Python literals. eval() would execute any code
    # that ended up in the CSV's embedding column.
    from ast import literal_eval

    if len(sys.argv) < 2:
        sys.exit('usage: {} <search query>'.format(sys.argv[0]))
    # Multi-word queries may be passed unquoted; a single argument is used as-is.
    query = ' '.join(sys.argv[1:])

    # 1. Load the search index.
    db = read_csv('search_db.csv')
    # 2. Each embedding cell holds the text of a Python list; parse it back
    #    into a numpy vector.
    db['embedding'] = db.embedding.apply(literal_eval).apply(np.array)
    print('')
    # 3. Rank and print the best matches.
    print(search(db, query))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment