Last active
May 19, 2023 15:03
-
-
Save JasonTam/89ff752d7e35ec17d730c87aea96c19b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Basic Negative Sampling Implementations" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import pandas as pd\n", | |
| "import numpy as np\n", | |
| "from scipy import stats\n", | |
| "from functools import partial\n", | |
| "np.random.seed(322)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>n_items</th>\n", | |
| " <th>pos_inds</th>\n", | |
| " <th>n_samp</th>\n", | |
| " <th>frac_pos</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>case0</th>\n", | |
| " <td>25</td>\n", | |
| " <td>[3, 9, 22]</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0.12000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>case1</th>\n", | |
| " <td>25</td>\n", | |
| " <td>[3, 9, 22]</td>\n", | |
| " <td>100</td>\n", | |
| " <td>0.12000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>case2</th>\n", | |
| " <td>25000</td>\n", | |
| " <td>[3, 9, 22]</td>\n", | |
| " <td>100</td>\n", | |
| " <td>0.00012</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>case3</th>\n", | |
| " <td>25</td>\n", | |
| " <td>[3, 9, 22]</td>\n", | |
| " <td>10000</td>\n", | |
| " <td>0.12000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>case4</th>\n", | |
| " <td>25</td>\n", | |
| " <td>[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...</td>\n", | |
| " <td>100</td>\n", | |
| " <td>0.88000</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " n_items pos_inds n_samp \\\n", | |
| "case0 25 [3, 9, 22] 1 \n", | |
| "case1 25 [3, 9, 22] 100 \n", | |
| "case2 25000 [3, 9, 22] 100 \n", | |
| "case3 25 [3, 9, 22] 10000 \n", | |
| "case4 25 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... 100 \n", | |
| "\n", | |
| " frac_pos \n", | |
| "case0 0.12000 \n", | |
| "case1 0.12000 \n", | |
| "case2 0.00012 \n", | |
| "case3 0.12000 \n", | |
| "case4 0.88000 " | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "cases_df = pd.DataFrame.from_dict({\n", | |
| " 'case0': {\n", | |
| " 'n_items': 25,\n", | |
| " 'pos_inds': np.array([3, 9, 22]),\n", | |
| " 'n_samp': 1,\n", | |
| " },\n", | |
| " 'case1': {\n", | |
| " 'n_items': 25,\n", | |
| " 'pos_inds': np.array([3, 9, 22]),\n", | |
| " 'n_samp': 100,\n", | |
| " },\n", | |
| " 'case2': {\n", | |
| " 'n_items': 25_000,\n", | |
| " 'pos_inds': np.array([3, 9, 22]),\n", | |
| " 'n_samp': 100,\n", | |
| " },\n", | |
| " 'case3': {\n", | |
| " 'n_items': 25,\n", | |
| " 'pos_inds': np.array([3, 9, 22]),\n", | |
| " 'n_samp': 10_000,\n", | |
| " },\n", | |
| " 'case4': {\n", | |
| " 'n_items': 25,\n", | |
| " 'pos_inds': np.arange(25-3),\n", | |
| " 'n_samp': 100,\n", | |
| " },\n", | |
| "}, orient='index')\n", | |
| "cases_df['frac_pos'] = cases_df['pos_inds'].map(len) / cases_df['n_items']\n", | |
| "cases_df" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Incremental Guess and Check" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def negsamp_incr(pos_check, pos_inds, n_items, n_samp=32):\n", | |
| " \"\"\" Guess and check with arbitrary positivity check\n", | |
| " \"\"\"\n", | |
| " neg_inds = []\n", | |
| " while len(neg_inds) < n_samp:\n", | |
| " raw_samp = np.random.randint(0, n_items)\n", | |
| " if not pos_check(raw_samp, pos_inds):\n", | |
| " neg_inds.append(raw_samp)\n", | |
| " return neg_inds" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def negsamp_incr_naive(pos_inds, n_items, n_samp=32):\n", | |
| " \"\"\" Guess and check with list membership\n", | |
| " \"\"\"\n", | |
| " pos_check = lambda raw_samp, pos_inds: raw_samp in pos_inds\n", | |
| " return negsamp_incr(pos_check, pos_inds, n_items, n_samp)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def negsamp_incr_set(pos_inds, n_items, n_samp=32):\n", | |
| " \"\"\" Guess and check with hashtable membership\n", | |
| " \"\"\"\n", | |
| " pos_inds = set(pos_inds)\n", | |
| " pos_check = lambda raw_samp, pos_inds: raw_samp in pos_inds\n", | |
| " return negsamp_incr(pos_check, pos_inds, n_items, n_samp)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from bisect import bisect_left\n", | |
| "\n", | |
| "def bsearch_in(search_val, val_arr):\n", | |
| " i = bisect_left(val_arr, search_val)\n", | |
| " return i != len(val_arr) and val_arr[i] == search_val\n", | |
| " \n", | |
| "def negsamp_incr_bsearch(pos_inds, n_items, n_samp=32):\n", | |
| " \"\"\" Guess and check with binary search\n", | |
| " `pos_inds` is assumed to be ordered\n", | |
| " \"\"\"\n", | |
| " pos_check = bsearch_in\n", | |
| " return negsamp_incr(pos_check, pos_inds, n_items, n_samp)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Vectorized Guess and Check" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def negsamp_vectorized_bsearch(pos_inds, n_items, n_samp=32):\n", | |
| " \"\"\" Guess and check vectorized\n", | |
| " Assumes that we are allowed to potentially \n", | |
| " return less than n_samp samples\n", | |
| " \"\"\"\n", | |
| " raw_samps = np.random.randint(0, n_items, size=n_samp)\n", | |
| " ss = np.searchsorted(pos_inds, raw_samps)\n", | |
| " pos_mask = raw_samps == np.take(pos_inds, ss, mode='clip')\n", | |
| " neg_inds = raw_samps[~pos_mask]\n", | |
| " return neg_inds" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Vectorized Pre-verified" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def negsamp_vectorized_bsearch_preverif(pos_inds, n_items, n_samp=32):\n", | |
| " \"\"\" Pre-verified with binary search\n", | |
| " `pos_inds` is assumed to be ordered\n", | |
| " \"\"\"\n", | |
| " raw_samp = np.random.randint(0, n_items - len(pos_inds), size=n_samp)\n", | |
| " pos_inds_adj = pos_inds - np.arange(len(pos_inds))\n", | |
| " neg_inds = raw_samp + np.searchsorted(pos_inds_adj, raw_samp, side='right')\n", | |
| " return neg_inds" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Sanity Checking" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "strategies = [\n", | |
| " negsamp_incr_naive,\n", | |
| " negsamp_incr_set,\n", | |
| " negsamp_incr_bsearch,\n", | |
| " negsamp_vectorized_bsearch,\n", | |
| " negsamp_vectorized_bsearch_preverif,\n", | |
| "]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Just a quick sanity check\n", | |
| "def is_valid(samps, n_items, pos_inds):\n", | |
| " densities, bin_edges = np.histogram(samps, bins=range(n_items+1), density=True)\n", | |
| " neg_inds = np.array(list(set(np.arange(n_items)) - set(pos_inds)))\n", | |
| " # Should not have any positives sampled as negatives\n", | |
| " is_non_pos = not densities[pos_inds].any()\n", | |
| " # Distribution should be ~uniform\n", | |
| " is_uniform = (stats.chisquare(densities[neg_inds]).pvalue > 0.95 or \n", | |
| " len(samps)<=1 # let's be forgiving if n_samp=1\n", | |
| " )\n", | |
| " return is_non_pos and is_uniform\n", | |
| "\n", | |
| "for case_name, row in cases_df.iterrows():\n", | |
| " for strat in strategies:\n", | |
| " samps = strat(row['pos_inds'], row['n_items'], row['n_samp'])\n", | |
| " assert is_valid(samps, row['n_items'], row['pos_inds'])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Timed Runs" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Run our strategies against our cases\n", | |
| "times_d = {}\n", | |
| "for case_name, row in cases_df.iterrows():\n", | |
| " d = {}\n", | |
| " for strat in strategies:\n", | |
| " r = %timeit -o -q strat(row['pos_inds'], row['n_items'], row['n_samp'])\n", | |
| " d[strat.__name__] = r.average\n", | |
| " times_d[case_name] = d" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>case0</th>\n", | |
| " <th>case1</th>\n", | |
| " <th>case2</th>\n", | |
| " <th>case3</th>\n", | |
| " <th>case4</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>negsamp_incr_naive</th>\n", | |
| " <td>0.000025</td>\n", | |
| " <td>0.000557</td>\n", | |
| " <td>0.000495</td>\n", | |
| " <td>0.051792</td>\n", | |
| " <td>0.003816</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>negsamp_incr_set</th>\n", | |
| " <td>0.000021</td>\n", | |
| " <td>0.000185</td>\n", | |
| " <td>0.000171</td>\n", | |
| " <td>0.016344</td>\n", | |
| " <td>0.001264</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>negsamp_incr_bsearch</th>\n", | |
| " <td>0.000020</td>\n", | |
| " <td>0.000255</td>\n", | |
| " <td>0.000207</td>\n", | |
| " <td>0.023141</td>\n", | |
| " <td>0.001898</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>negsamp_vectorized_bsearch</th>\n", | |
| " <td>0.000025</td>\n", | |
| " <td>0.000027</td>\n", | |
| " <td>0.000027</td>\n", | |
| " <td>0.000214</td>\n", | |
| " <td>0.000029</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>negsamp_vectorized_bsearch_preverif</th>\n", | |
| " <td>0.000024</td>\n", | |
| " <td>0.000025</td>\n", | |
| " <td>0.000025</td>\n", | |
| " <td>0.000203</td>\n", | |
| " <td>0.000025</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " case0 case1 case2 case3 \\\n", | |
| "negsamp_incr_naive 0.000025 0.000557 0.000495 0.051792 \n", | |
| "negsamp_incr_set 0.000021 0.000185 0.000171 0.016344 \n", | |
| "negsamp_incr_bsearch 0.000020 0.000255 0.000207 0.023141 \n", | |
| "negsamp_vectorized_bsearch 0.000025 0.000027 0.000027 0.000214 \n", | |
| "negsamp_vectorized_bsearch_preverif 0.000024 0.000025 0.000025 0.000203 \n", | |
| "\n", | |
| " case4 \n", | |
| "negsamp_incr_naive 0.003816 \n", | |
| "negsamp_incr_set 0.001264 \n", | |
| "negsamp_incr_bsearch 0.001898 \n", | |
| "negsamp_vectorized_bsearch 0.000029 \n", | |
| "negsamp_vectorized_bsearch_preverif 0.000025 " | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Average times\n", | |
| "times_df = pd.DataFrame.from_dict(times_d, orient='columns')\\\n", | |
| " .reindex([s.__name__ for s in strategies])\n", | |
| "times_df" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>case0</th>\n", | |
| " <th>case1</th>\n", | |
| " <th>case2</th>\n", | |
| " <th>case3</th>\n", | |
| " <th>case4</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>negsamp_incr_naive</th>\n", | |
| " <td>1.000000</td>\n", | |
| " <td>1.000000</td>\n", | |
| " <td>1.000000</td>\n", | |
| " <td>1.000000</td>\n", | |
| " <td>1.000000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>negsamp_incr_set</th>\n", | |
| " <td>1.233551</td>\n", | |
| " <td>3.007306</td>\n", | |
| " <td>2.891718</td>\n", | |
| " <td>3.168921</td>\n", | |
| " <td>3.018263</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>negsamp_incr_bsearch</th>\n", | |
| " <td>1.274717</td>\n", | |
| " <td>2.186638</td>\n", | |
| " <td>2.392386</td>\n", | |
| " <td>2.238093</td>\n", | |
| " <td>2.010704</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>negsamp_vectorized_bsearch</th>\n", | |
| " <td>1.022648</td>\n", | |
| " <td>20.735678</td>\n", | |
| " <td>18.502182</td>\n", | |
| " <td>242.499565</td>\n", | |
| " <td>131.991317</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>negsamp_vectorized_bsearch_preverif</th>\n", | |
| " <td>1.079992</td>\n", | |
| " <td>21.908844</td>\n", | |
| " <td>19.691580</td>\n", | |
| " <td>255.305570</td>\n", | |
| " <td>150.949921</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " case0 case1 case2 \\\n", | |
| "negsamp_incr_naive 1.000000 1.000000 1.000000 \n", | |
| "negsamp_incr_set 1.233551 3.007306 2.891718 \n", | |
| "negsamp_incr_bsearch 1.274717 2.186638 2.392386 \n", | |
| "negsamp_vectorized_bsearch 1.022648 20.735678 18.502182 \n", | |
| "negsamp_vectorized_bsearch_preverif 1.079992 21.908844 19.691580 \n", | |
| "\n", | |
| " case3 case4 \n", | |
| "negsamp_incr_naive 1.000000 1.000000 \n", | |
| "negsamp_incr_set 3.168921 3.018263 \n", | |
| "negsamp_incr_bsearch 2.238093 2.010704 \n", | |
| "negsamp_vectorized_bsearch 242.499565 131.991317 \n", | |
| "negsamp_vectorized_bsearch_preverif 255.305570 150.949921 " | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Speedup factors compared to our most naive strategy\n", | |
| "speedup_df = times_df.loc['negsamp_incr_naive'] / times_df\n", | |
| "speedup_df" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.4" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment