Skip to content

Instantly share code, notes, and snippets.

@Hokid
Created February 13, 2023 08:01
Show Gist options
  • Save Hokid/06452b355a6b88a758a6c096df34ab5d to your computer and use it in GitHub Desktop.
Save Hokid/06452b355a6b88a758a6c096df34ab5d to your computer and use it in GitHub Desktop.
simple-code-search-engine-using-open-ai-api
const path = require('path');
const ts = require('typescript');
const csv = require('csv-stringify/sync');
// The project to index is whatever directory the script is launched from.
const cwd = process.cwd();
const configPath = path.join(cwd, 'tsconfig.json');
// tsconfig.json may legally contain comments and trailing commas, which a plain
// require()/JSON.parse rejects. Let the compiler's own reader load it instead.
const { config: configJSON, error: configError } = ts.readConfigFile(
  configPath,
  ts.sys.readFile
);
if (configError) {
  throw new Error(
    `Cannot read ${configPath}: ${ts.flattenDiagnosticMessageText(configError.messageText, '\n')}`
  );
}
// Expands the raw JSON into compiler options plus the list of root file names.
const config = ts.parseJsonConfigFileContent(configJSON, ts.sys, cwd);
const program = ts.createProgram(
  config.fileNames,
  config.options,
  ts.createCompilerHost(config.options)
);
// The checker resolves declaration names to symbols (used for names and doc comments).
const checker = program.getTypeChecker();
// Accumulates one record per discovered declaration; serialized to CSV at the end.
const rows = [];
/**
 * Appends one record to the shared `rows` list.
 *
 * @param {string} fileName - Path of the file containing the declaration; stored relative to `cwd`.
 * @param {string} name - Name under which the declaration is recorded.
 * @param {string} code - Source text (or stub) recorded for the declaration.
 * @param {string} [docs=''] - Documentation text for the declaration.
 * @returns {number} Whatever `rows.push` returns (the new length of `rows`).
 */
const addRow = (fileName, name, code, docs = '') => {
  const record = {
    file_name: path.relative(cwd, fileName),
    name,
    code,
    docs,
  };
  return rows.push(record);
};
/**
 * Records a function declaration as a row, using its full source text as the code.
 * Declarations whose name does not resolve to a symbol are ignored.
 *
 * @param {string} fileName - File the declaration was found in.
 * @param {object} node - Function declaration node; `node.name` and `node.getText()` are read.
 */
function addFunction(fileName, node) {
  const fnSymbol = checker.getSymbolAtLocation(node.name);
  if (!fnSymbol) {
    return;
  }
  const identifier = fnSymbol.getName();
  const documentation = getDocs(fnSymbol);
  const sourceText = node.getText();
  addRow(fileName, identifier, sourceText, documentation);
}
/**
 * Records a class declaration and then each of its members.
 * The class row carries only a `class Name {}` stub as its code; the member
 * bodies are recorded separately through `addClassMember`.
 * Classes whose name does not resolve to a symbol are ignored, members included.
 *
 * @param {string} fileName - File the declaration was found in.
 * @param {object} node - Class declaration node; `node.name` and `node.members` are read.
 */
function addClass(fileName, node) {
  const classSymbol = checker.getSymbolAtLocation(node.name);
  if (!classSymbol) {
    return;
  }
  const className = classSymbol.getName();
  const documentation = getDocs(classSymbol);
  const stub = `class ${className} {}`;
  addRow(fileName, className, stub, documentation);
  for (const member of node.members) {
    addClassMember(fileName, className, member);
  }
}
/**
 * Records a single class member as a row named `ClassName:memberName`,
 * using the member's full source text as the code.
 * Members whose name does not resolve to a symbol are ignored.
 *
 * @param {string} fileName - File the class was found in.
 * @param {string} className - Name of the owning class, used as the row-name prefix.
 * @param {object} node - Class member node; `node.name` and `node.getText()` are read.
 */
function addClassMember(fileName, className, node) {
  const memberSymbol = checker.getSymbolAtLocation(node.name);
  if (!memberSymbol) {
    return;
  }
  const qualifiedName = `${className}:${memberSymbol.getName()}`;
  const documentation = getDocs(memberSymbol);
  const sourceText = node.getText();
  addRow(fileName, qualifiedName, sourceText, documentation);
}
/**
 * Records an interface declaration and then hands each of its members to
 * `addInterfaceMember`. The interface row carries only an `interface Name {}`
 * stub as its code.
 * Interfaces whose name does not resolve to a symbol are ignored, members included.
 *
 * @param {string} fileName - File the declaration was found in.
 * @param {object} node - Interface declaration node; `node.name` and `node.members` are read.
 */
function addInterface(fileName, node) {
  const interfaceSymbol = checker.getSymbolAtLocation(node.name);
  if (!interfaceSymbol) {
    return;
  }
  const interfaceName = interfaceSymbol.getName();
  const documentation = getDocs(interfaceSymbol);
  const stub = `interface ${interfaceName} {}`;
  addRow(fileName, interfaceName, stub, documentation);
  for (const member of node.members) {
    addInterfaceMember(fileName, interfaceName, member);
  }
}
/**
 * Records a single interface member as a row named `InterfaceName:memberName`,
 * using the member's full source text as the code.
 * Only property and method signatures are recorded; any other member kind, and
 * any member whose name does not resolve to a symbol, is ignored.
 *
 * @param {string} fileName - File the interface was found in.
 * @param {string} interfaceName - Name of the owning interface, used as the row-name prefix.
 * @param {object} node - Interface member node; `node.name` and `node.getText()` are read.
 */
function addInterfaceMember(fileName, interfaceName, node) {
  // A node is never both kinds at once, so the two checks must be combined with
  // `&&`: skip only when the member is NEITHER a property NOR a method signature.
  // (With `||` the guard was always true and no member was ever recorded.)
  if (!ts.isPropertySignature(node) && !ts.isMethodSignature(node)) {
    return;
  }
  const symbol = checker.getSymbolAtLocation(node.name);
  if (!symbol) {
    return;
  }
  const name = `${interfaceName}:${symbol.getName()}`;
  const docs = getDocs(symbol);
  const code = node.getText();
  addRow(fileName, name, code, docs);
}
/**
 * Returns a symbol's documentation comment as a single string.
 *
 * @param {object} symbol - Symbol whose `getDocumentationComment(checker)` is read.
 * @returns {string} The display parts joined by `ts.displayPartsToString`.
 */
function getDocs(symbol) {
  const commentParts = symbol.getDocumentationComment(checker);
  return ts.displayPartsToString(commentParts);
}
// Ordered [type guard, recorder] pairs: the first guard that accepts a node
// decides which recorder handles it, and at most one recorder runs per node.
const declarationHandlers = [
  [ts.isFunctionDeclaration, addFunction],
  [ts.isClassDeclaration, addClass],
  [ts.isInterfaceDeclaration, addInterface],
];
// Walk every node of every root file, recording each function, class and
// interface declaration encountered, at any nesting depth.
for (const fileName of config.fileNames) {
  const walk = (node) => {
    const match = declarationHandlers.find(([accepts]) => accepts(node));
    if (match) {
      const [, record] = match;
      record(fileName, node);
    }
    // Always descend, whether or not this node itself was recorded.
    ts.forEachChild(node, walk);
  };
  ts.forEachChild(program.getSourceFile(fileName), walk);
}
// Give every row a `combined` text field: the documentation (only when there is
// any), followed by the code and the name.
rows.forEach((row) => {
  const docsPart = row.docs ? `Code documentation: ${row.docs}; ` : '';
  row.combined = `${docsPart}Code: ${row.code}; Name: ${row.name};`;
});
// Serialize all rows as CSV with a header line and write it to stdout.
const output = csv.stringify(rows, { header: true });
console.log(output);
from io import StringIO
from subprocess import PIPE, run
from pandas import read_csv
from openai.embeddings_utils import get_embedding as _get_embedding
from tenacity import wait_random_exponential, stop_after_attempt
# Re-wrap the library's get_embedding with this script's own retry policy:
# wait_random_exponential(min=1, max=60) between tries, stopping after 10 attempts.
# NOTE(review): assumes _get_embedding is tenacity-decorated so that `retry_with`
# exists on it — confirm against the installed openai version.
get_embedding = _get_embedding.retry_with(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(10))
if __name__ == '__main__':
    # 1. Run the Node extractor and capture its stdout/stderr as text.
    extractor = run(
        ['node', 'code-to-csv.js'],
        stdout=PIPE,
        stderr=PIPE,
        universal_newlines=True,
    )
    if extractor.returncode != 0:
        # Surface the extractor's own error output instead of continuing.
        raise RuntimeError(extractor.stderr)

    # 2. Parse the CSV the extractor printed on stdout.
    db = read_csv(StringIO(extractor.stdout))

    # 3. Compute one embedding per row from its `combined` text.
    def embed(text):
        return get_embedding(text, engine='text-embedding-ada-002')

    db['embedding'] = db['combined'].apply(embed)

    # 4. Persist the table, embeddings included, without the index column.
    db.to_csv("search_db.csv", index=False)
import sys
import numpy as np
from pandas import read_csv
from openai.embeddings_utils import cosine_similarity, get_embedding as _get_embedding
from tenacity import stop_after_attempt, wait_random_exponential
# Re-wrap the library's get_embedding with this script's own retry policy:
# wait_random_exponential(min=1, max=60) between tries, stopping after 10 attempts.
# NOTE(review): assumes _get_embedding is tenacity-decorated so that `retry_with`
# exists on it — confirm against the installed openai version.
get_embedding = _get_embedding.retry_with(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(10))
def search(db, query, top_n=3, max_code_lines=7):
    """Return a printable report of the rows of ``db`` most similar to ``query``.

    The query is embedded with ``get_embedding`` and compared with every row's
    ``embedding`` via ``cosine_similarity``. The best ``top_n`` rows are
    rendered highest score first.

    Args:
        db: DataFrame with ``embedding``, ``docs``, ``code``, ``file_name`` and
            ``name`` columns. It is not modified.
        query: Free-text search query.
        top_n: Maximum number of results to render.
        max_code_lines: Maximum number of leading code lines shown per result.

    Returns:
        One string with, per result: the docs as a ``/** ... */`` comment (only
        when the row has string docs), the code preview, a
        ``[score=...] file:name`` line and a dashed separator. Empty string when
        there are no rows.
    """
    query_embedding = get_embedding(query, engine='text-embedding-ada-002')
    similarities = db['embedding'].apply(lambda x: cosine_similarity(x, query_embedding))
    # Rank on a copy so the caller's frame keeps its columns and row order.
    ranked = db.assign(similarities=similarities).sort_values('similarities', ascending=False)

    parts = []
    for row in ranked.head(top_n).itertuples(index=False):
        score = round(row.similarities, 3)
        # Rows without docs hold a non-string placeholder (NaN) after a CSV round trip.
        if isinstance(row.docs, str):
            parts.append('/**\n * {docs}\n */\n'.format(docs='\n * '.join(row.docs.split('\n'))))
        parts.append('{code}\n\n'.format(code='\n'.join(row.code.split('\n')[:max_code_lines])))
        parts.append('[score={score}] {file_name}:{name}\n'.format(score=score, file_name=row.file_name, name=row.name))
        parts.append('-' * 70 + '\n\n')
    return ''.join(parts)
if __name__ == '__main__':
    # Imported here so this entry point carries its own extra dependency.
    from ast import literal_eval

    # Fail with a usage message instead of an IndexError when no query is given.
    if len(sys.argv) < 2:
        sys.exit('usage: python {} "<search query>"'.format(sys.argv[0]))
    query = sys.argv[1]

    # 1. Load the table of rows and their stored embeddings.
    db = read_csv('search_db.csv')
    # 2. Each embedding cell is the text form of a list of numbers. Parse it with
    #    literal_eval rather than eval so that file contents can never run as code.
    db['embedding'] = db.embedding.apply(literal_eval).apply(np.array)
    print('')
    # 3. Rank the rows against the query and print the report.
    print(search(db, query))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment