Skip to content

Instantly share code, notes, and snippets.

@rsaryev
Last active August 15, 2023 22:36
Show Gist options
  • Save rsaryev/40ce5d844370faed07a9d57f553f1698 to your computer and use it in GitHub Desktop.
Save rsaryev/40ce5d844370faed07a9d57f553f1698 to your computer and use it in GitHub Desktop.
Find similar professions using a vector-embedding model
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"start_time": "2023-06-08T11:40:59.311108Z",
"end_time": "2023-06-08T11:41:01.160608Z"
},
"collapsed": true
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import pickle\n",
"import os\n",
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv()\n",
"\n",
"if os.getenv(\"OPENAI_API_KEY\") is not None:\n",
" import openai\n",
"\n",
" openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"else:\n",
" print(\"OPENAI_API_KEY environment variable not found\")\n",
"\n",
"from openai.embeddings_utils import (\n",
" get_embedding,\n",
" distances_from_embeddings,\n",
" indices_of_nearest_neighbors_from_distances\n",
")\n",
"\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"start_time": "2023-06-08T11:41:01.157258Z",
"end_time": "2023-06-08T11:41:01.167264Z"
}
},
"outputs": [],
"source": [
"dataset_path = \"../data/professions.csv\"\n",
"df = pd.read_csv(dataset_path)\n",
"professions = df[\"en\"].tolist()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"start_time": "2023-06-08T11:41:01.159356Z",
"end_time": "2023-06-08T11:41:01.167455Z"
}
},
"outputs": [],
"source": [
"embedding_cache_path = \"../cache/recommendations_embeddings_cache.pkl\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"start_time": "2023-06-08T11:41:01.167823Z",
"end_time": "2023-06-08T11:41:01.171016Z"
}
},
"outputs": [],
"source": [
"if os.path.isfile(embedding_cache_path):\n",
" embedding_cache = pd.read_pickle(embedding_cache_path)\n",
"else:\n",
" embedding_cache = {}\n",
" with open(embedding_cache_path, \"wb\") as embedding_cache_file:\n",
" pickle.dump(embedding_cache, embedding_cache_file)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"start_time": "2023-06-08T11:41:01.176021Z",
"end_time": "2023-06-08T11:41:01.177402Z"
}
},
"outputs": [],
"source": [
"def embedding_from_profession(\n",
" profession: str,\n",
" model: str = EMBEDDING_MODEL,\n",
" embedding_cache=embedding_cache\n",
") -> list:\n",
" if (profession, model) not in embedding_cache.keys():\n",
" embedding_cache[(profession, model)] = get_embedding(profession, model)\n",
" with open(embedding_cache_path, \"wb\") as embedding_cache_file:\n",
" pickle.dump(embedding_cache, embedding_cache_file)\n",
" return embedding_cache[(profession, model)]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"start_time": "2023-06-08T11:41:01.178729Z",
"end_time": "2023-06-08T11:41:01.182543Z"
}
},
"outputs": [],
"source": [
"def get_similar_professions(query: str, k: int = 10) -> dict:\n",
" embeddings = [embedding_from_profession(title, model=EMBEDDING_MODEL) for title in professions]\n",
" query_embedding = embeddings[professions.index(query)]\n",
" distances = distances_from_embeddings(query_embedding, embeddings, distance_metric=\"cosine\")\n",
" indices = indices_of_nearest_neighbors_from_distances(distances)\n",
"\n",
" return {professions[index]: 1 - distances[index] for index in indices[:k]}"
]
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "{'Waiter': 1,\n 'Restaurant employee': 0.9002262329129079,\n 'Assistant cook': 0.8755513786133987,\n 'Cashier-cook': 0.874830050166545,\n 'Cashier': 0.8744767651558191,\n 'Hostess': 0.8727183085593714,\n 'Barista': 0.8646971756662765,\n 'Cook': 0.8579846932425619,\n 'Security guard': 0.8563417133719982,\n 'Dishwasher': 0.8521875779485968}"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_similar_professions(\"Waiter\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-06-08T11:41:01.184028Z",
"end_time": "2023-06-08T11:41:01.229569Z"
}
}
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"start_time": "2023-06-08T11:41:14.160396Z",
"end_time": "2023-06-08T11:41:14.527895Z"
}
},
"outputs": [],
"source": [
"data = []\n",
"for i, profession in enumerate(professions):\n",
" similar_professions = get_similar_professions(profession)\n",
" root_profession_id = df.iloc[i][\"id\"]\n",
" for similar_profession in similar_professions.keys():\n",
" similar_profession_id = df[df[\"en\"] == similar_profession][\"id\"].values[0]\n",
" data.append({\n",
" \"root_profession_id\": root_profession_id,\n",
" \"root_profession\": profession,\n",
" \"similar_profession_id\": similar_profession_id,\n",
" \"similar_profession\": similar_profession,\n",
" \"score\": similar_professions[similar_profession],\n",
" })\n",
"\n",
"dataframe = pd.DataFrame(data, columns=[\"root_profession_id\", \"root_profession\", \"similar_profession_id\",\n",
" \"similar_profession\", \"score\"])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"start_time": "2023-06-08T11:41:16.037701Z",
"end_time": "2023-06-08T11:41:16.069886Z"
}
},
"outputs": [],
"source": [
"dataframe.to_csv(\"../data/similar_professions.csv\", index=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment