Last active
July 10, 2025 18:46
-
-
Save darcyabjones/ace0934ae19c56ed4babc50167f282b0 to your computer and use it in GitHub Desktop.
A quick introduction to learning to rank models.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "sensitive-recognition", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import pandas as pd\n", | |
| "import numpy as np\n", | |
| "import seaborn as sns\n", | |
| "from matplotlib import pyplot as plt\n", | |
| "from statistics import mean" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "precise-colleague", | |
| "metadata": {}, | |
| "source": [ | |
| "# A brief intro to gradient descent and neural networks\n", | |
| "\n", | |
| "Feel free to skip." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "perceived-community", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def log_loss(true, pred):\n", | |
| " return ((true * np.log(pred)) + ((1 - true) * np.log(1 - pred))).sum() / -true.shape[0]\n", | |
| "\n", | |
| "def sigmoid(pred):\n", | |
| " return 1 / (1 + np.exp(-pred))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "great-shopper", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "linspace = np.linspace(-5, 5, 100)\n", | |
| "ax = sns.scatterplot(y=sigmoid(linspace), x=linspace) \n", | |
| "ax.set_ylabel(\"f(x)\")\n", | |
| "ax.set_xlabel(\"x\");" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "distinct-ethiopia", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "linspace = sigmoid(np.linspace(-5, 5, 100))\n", | |
| "true_class = 0.0\n", | |
| "losses = np.apply_along_axis(lambda x: log_loss(np.array([true_class]), x), 0, linspace.reshape(1, -1))\n", | |
| "ax = sns.scatterplot(y=losses, x=linspace)\n", | |
| "ax.set_xlabel(\"Prediction\")\n", | |
| "ax.set_ylabel(\"Loss\");" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "protected-blank", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "coef = np.linspace(0, 10, 100)\n", | |
| "loss = (coef - 5) ** 2\n", | |
| "\n", | |
| "line = lambda x: 2 * (x - 5)\n", | |
| "ax = sns.scatterplot(x=coef, y=loss)\n", | |
| "plt.plot([5, 10], line(np.array([5, 10])), color=\"red\", label=\"derivative\")\n", | |
| "plt.legend()\n", | |
| "ax.set_xlabel(\"Coefficient\")\n", | |
| "ax.set_ylabel(\"Loss\");" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "advance-shannon", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[ 0.8217988 , 0.91751578],\n", | |
| " [-0.42918603, -0.56585462],\n", | |
| " [-0.74825694, 2.65235002],\n", | |
| " [-0.70705092, -0.6665801 ],\n", | |
| " [ 0.85580474, 0.89543249]])" | |
| ] | |
| }, | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X = np.random.normal(size=200).reshape(-1, 2)\n", | |
| "X[:5]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "thick-ocean", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([1., 0., 1., 0., 1.])" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "y = np.round(sigmoid((2 * X[:, 0] + 3 * X[:, 1])))\n", | |
| "y[:5]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "advanced-bread", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "initial weights [0.08587948 0.0992165 ] \n", | |
| "\n", | |
| "loss: 0.6327634538180974\n", | |
| "weights: [0.10381509 0.12992812]\n", | |
| "loss: 0.6208585100088778\n", | |
| "weights: [0.12138765 0.15986252]\n", | |
| "loss: 0.6095124985129491\n", | |
| "weights: [0.1386076 0.18904531]\n", | |
| "loss: 0.5986938306967757\n", | |
| "weights: [0.15548549 0.21750214]\n", | |
| "loss: 0.5883723485148004\n", | |
| "weights: [0.1720319 0.24525852]\n", | |
| "loss: 0.5785193747199389\n", | |
| "weights: [0.18825734 0.27233962]\n", | |
| "loss: 0.5691077334359955\n", | |
| "weights: [0.20417221 0.29877017]\n", | |
| "loss: 0.5601117464238963\n", | |
| "weights: [0.21978674 0.32457433]\n", | |
| "loss: 0.551507209956081\n", | |
| "weights: [0.23511096 0.34977561]\n", | |
| "loss: 0.5432713566872078\n", | |
| "weights: [0.25015464 0.37439684]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Fit a two-weight logistic regression (no intercept) to X, y with batch gradient descent.\n", | |
| "learning_rate = 0.1\n", | |
| "nepoch = 10\n", | |
| "\n", | |
| "# Small random starting weights, kept as a (1, 2) row vector.\n", | |
| "w = np.random.normal(scale=0.1, size=2).reshape(1, 2)\n", | |
| "print(\"initial weights\", w[0], \"\\n\")\n", | |
| "for i in range(nepoch):\n", | |
| " preds = sigmoid(w.dot(X.T))[0]\n", | |
| " # (y - preds) * X averaged over the rows is the negative gradient of the mean log loss,\n", | |
| " # so adding a fraction of it to w moves the weights downhill.\n", | |
| " gradient = np.sum((y - preds).reshape(-1, 1) * X, axis=0).reshape(1, 2) / X.shape[0]\n", | |
| " w += learning_rate * gradient\n", | |
| " #print(gradient)\n", | |
| "\n", | |
| " # Re-score with the updated weights so the reported loss reflects this step.\n", | |
| " preds = sigmoid(w.dot(X.T))[0]\n", | |
| " loss = log_loss(y, preds)\n", | |
| "\n", | |
| " print(\"loss:\", loss)\n", | |
| " print(\"weights:\", w[0])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "statutory-rotation", | |
| "metadata": { | |
| "slideshow": { | |
| "slide_type": "slide" | |
| } | |
| }, | |
| "source": [ | |
| "# Learning to rank\n", | |
| "\n", | |
| "- A supervised learning task\n", | |
| "- Attempts to optimally order a list (or lists) of items.\n", | |
| "- Usually used for information retrieval: search engines, recommenders.\n", | |
| "\n", | |
| "There are currently 3 general kinds of methods for this task.\n", | |
| "\n", | |
| "- Point-wise\n", | |
| " - Basically just regression or classification\n", | |
| " - Considers each list item independently of others\n", | |
| " - Rankprop (MSE regression with adjustments)\n", | |
| " - https://proceedings.neurips.cc/paper/1995/hash/36a16a2505369e0c922b6ea7a23a56d2-Abstract.html\n", | |
| " - P-rank (ordinal regression)\n", | |
| " - https://papers.nips.cc/paper/2001/file/5531a5834816222280f20d1ef9e95f69-Paper.pdf\n", | |
| " - mcrank (multiple classification/ordinal classification)\n", | |
| " - http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.992.5789&rep=rep1&type=pdf\n", | |
| "\n", | |
| "- Pair-wise\n", | |
| " - Consider pairs of items and decide which should be higher\n", | |
| " - RankingSVM/RankSVM/SVMrank\n", | |
| " - https://www.hindawi.com/journals/cin/2017/4629534/\n", | |
| " - https://www.cs.cornell.edu/people/tj/svm_light/svm_rank.html\n", | |
| " - https://doi.org/10.1007/978-3-642-01307-2_39\n", | |
| " - RankBoost\n", | |
| " - https://www.jmlr.org/papers/volume4/freund03a/freund03a.pdf\n", | |
| " - Ranknet\n", | |
| " - https://www.microsoft.com/en-us/research/wp-content/uploads/2005/08/icml_ranking.pdf\n", | |
| " - LambdaRank/LambdaMART\n", | |
| " - https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/lambdarank.pdf\n", | |
| "\n", | |
| "- List-wise\n", | |
| " - Optimise the ordering of whole lists of items at a time\n", | |
| " - Requires numerous lists of items\n", | |
| " - Often used for information retrieval (match relevance) or product display order (likelihood of sale)\n", | |
| " - e.g. LambdaRank/LambdaMART\n", | |
| " - softRank\n", | |
| " - listRank\n", | |
| " - listMLE\n", | |
| " \n", | |
| " \n", | |
| " \n", | |
| "Recommended resource from Guido Zuccon\n", | |
| "- Hang Li, \"Learning to Rank for Information Retrieval and Natural Language Processing\" https://www.morganclaypool.com/doi/abs/10.2200/S00607ED2V01Y201410HLT026\n", | |
| "\n", | |
| " \n", | |
| "\n", | |
| "LTR is closely related to (and is incorporated in) some recommender system problems, but the literature doesn't seem to intersect much.\n", | |
| "Lots more information available on recommenders so I'd add it as a search term if you're trying to learn this stuff.\n", | |
| "\n", | |
| "Collaborative filtering might form part of your initial selection of documents to present.\n", | |
| "Ranking and reranking in recommenders often don't use the term \"learning to rank\" so look around.\n", | |
| "\n", | |
| "- https://arxiv.org/ftp/arxiv/papers/1205/1205.2618.pdf\n", | |
| "- https://arxiv.org/pdf/1812.04109.pdf\n", | |
| "- https://arxiv.org/pdf/1904.06813.pdf\n", | |
| "- https://arxiv.org/abs/1804.05936" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "french-macro", | |
| "metadata": { | |
| "slideshow": { | |
| "slide_type": "slide" | |
| } | |
| }, | |
| "source": [ | |
| "## Data and labels\n", | |
| "\n", | |
| "Features are usually some indicator of relevance to a query term.\n", | |
| "\n", | |
| "- Distance/Similarity of word/sentence embeddings\n", | |
| "- PageRank\n", | |
| "- Clicks\n", | |
| "- User history\n", | |
| "\n", | |
| "They could also just be general features used for regression or classification, with the query term included as a categorical factor.\n", | |
| "\n", | |
| "> A lot of feature engineering has historically been done in search engines (so you can imagine how happy people are to be moving to DL).\n", | |
| "> Complex indexing time features are often computed.\n", | |
| "> Complex query time features are more challenging, as they add to the query latency.\n", | |
| "> So often a cascade of rankers is employed (e.g. Twitter uses 3 or 4 depths of cascades) with the complex query-time features being computed only at later stages.\n", | |
| ">\n", | |
| "> \\- Guido Zuccon\n", | |
| "\n", | |
| "Labels used are relevance or priority scores for each item in your training/test datasets (e.g. 0, 1, 2, 3, 4).\n", | |
| "\n", | |
| "Can also be pairwise decisions on whether one should be higher than the other (like a binary classifier for each pair)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "informed-window", | |
| "metadata": {}, | |
| "source": [ | |
| "## Why do I care about this stuff?\n", | |
| "\n", | |
| "\n", | |
| "#### Protein functional annotation\n", | |
| "\n", | |
| "I first came across LTR when I was looking at new gene ontology term prediction methods and found [GOLabeler](https://academic.oup.com/bioinformatics/article/34/14/2465/4924212).\n", | |
| "GOLabeler takes a protein sequence, performs several analyses and uses LTR to order the resulting GO terms by relevance (i.e. a multi-label problem).\n", | |
| "\n", | |
| "#### Disease resistance scoring\n", | |
| "\n", | |
| "Quantifying (phenotypically) complex traits and combinations of traits can be difficult.\n", | |
| "Scoring methods have been used for years, and researchers are hesitant/unable to move into more automated areas.\n", | |
| "\n", | |
| "Below is a typical scoring key used for a wheat necrotrophic pathogen (Ptr/Yellow spot)...\n", | |
| "\n", | |
| "\n", | |
| "\n", | |
| "> Dinglasan, E., Godwin, I.D., Mortlock, M.Y. et al. Resistance to yellow spot in wheat grown under accelerated growth conditions. Euphytica 209, 693–707 (2016). https://doi.org/10.1007/s10681-016-1660-z\n", | |
| "\n", | |
| "\n", | |
| "It's difficult to assign a distinct quantitative score of disease severity (or even an ordinal one), but it's much easier to say that one leaf is more severely diseased than another.\n", | |
| "\n", | |
| "#### Protein candidate ranking.\n", | |
| "\n", | |
| "There is a class of proteins called \"effectors\" that we're interested in.\n", | |
| "Effectors share little sequence similarity but have some common characteristics related to size, charge, motifs etc.\n", | |
| "But not all indicators of these characteristics are perfect (e.g. signal peptide prediction), and not all effectors have these characteristics (e.g. some effectors are large).\n", | |
| "\n", | |
| "The idea is to use LTR to learn a static sorting function, so that things with more of the indicator characteristics are nearer the top of the list in a big spreadsheet but we're not \"throwing out\" good candidates (they're just a bit further down the list).\n", | |
| "\n", | |
| "Wrote a pipeline called [predector](https://github.com/ccdmb/predector) based on this idea, and it's super useful.\n", | |
| "\n", | |
| "\n", | |
| "#### A few other interesting cases\n", | |
| "\n", | |
| "- Drug selection for individual patients based on genomic information\n", | |
| " - https://ieeexplore.ieee.org/abstract/document/8395000?casa_token=U1XwW3DNc34AAAAA:gLBTlk1uCfJQZpxTi_emFrCGHB23WiFNL2engdCvI8yrQMTFao3rslhgvs3kLFfZ48k113I9WF7u\n", | |
| "- Ranking individual treatment effects (Uplift modelling)\n", | |
| " - https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9311871&tag=1\n", | |
| "- Disease name normalisation/standardisation\n", | |
| " - https://academic.oup.com/bioinformatics/article/29/22/2909/312804?login=true\n", | |
| "\n", | |
| "\n", | |
| "---\n", | |
| "\n", | |
| "\n", | |
| "Adjusting the objective of the modelling rather than changing requirements of researchers/input:\n", | |
| "\n", | |
| "- lower barrier to uptake by researchers (easier data generation, results closer to what they expect)\n", | |
| "- viable intermediate tools until traits can be broken down into separate components\n", | |
| "- potential for interactive reinforcement/labelling" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "welsh-ecology", | |
| "metadata": { | |
| "slideshow": { | |
| "slide_type": "slide" | |
| } | |
| }, | |
| "source": [ | |
| "## Metrics\n", | |
| "\n", | |
| "Specialised evaluation metrics are all list-wise.\n", | |
| "Point- and pair-wise algorithms use a different loss function during training.\n", | |
| "\n", | |
| "Binary relevance (0 or 1):\n", | |
| "- Mean average precision (MAP) - Like an AUC for ranking.\n", | |
| "- Mean reciprocal rank (MRR) - Reciprocal of the position of the first relevant item, averaged over lists\n", | |
| "- accuracy@k (or precision, recall etc)\n", | |
| "\n", | |
| "Multi-level relevance (0, 1, 2, ...):\n", | |
| "- Normalised DCG (NDCG)\n", | |
| "- Expected reciprocal rank (ERR) https://dl.acm.org/doi/pdf/10.1145/1645953.1646033" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "correct-variation", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.6111111111111112" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "def mrr(true, pred):\n", | |
| " \"\"\" Mean reciprocal rank\n", | |
| " \n", | |
| " all you care about is the top hit.\n", | |
| " \n", | |
| " true: array of 0 or 1 indicating relevance\n", | |
| " pred: scoring function producing sorted list in descending order.\n", | |
| " \"\"\"\n", | |
| " \n", | |
| " import numpy as np\n", | |
| " from statistics import mean\n", | |
| " \n", | |
| " reciprocal_ranks = []\n", | |
| " \n", | |
| " for i in range(true.shape[0]):\n", | |
| " order = np.argsort(pred[i])[::-1]\n", | |
| " ranked = true[i][order]\n", | |
| "\n", | |
| " # Find the position of the first element ranked as relevant\n", | |
| " list_rank = np.argmax(ranked > 0) + 1 # Correct for zero indexing.\n", | |
| "\n", | |
| " reciprocal_ranks.append(1 / list_rank)\n", | |
| "\n", | |
| " return mean(reciprocal_ranks) # Mean of list scores\n", | |
| "\n", | |
| "p = np.array([\n", | |
| " [3, 2, 1],\n", | |
| " [3, 2, 1],\n", | |
| " [3, 2, 1]\n", | |
| "])\n", | |
| "\n", | |
| "t = np.array([\n", | |
| " [0, 0, 1],\n", | |
| " [0, 1, 1],\n", | |
| " [1, 1, 0]\n", | |
| "])\n", | |
| "\n", | |
| "mrr(t, p)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "designed-blowing", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.33333333, 0.58333333, 1. ])" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "def average_precision(true, pred):\n", | |
| " import numpy as np\n", | |
| " \n", | |
| " average_precisions = []\n", | |
| " \n", | |
| " for i in range(true.shape[0]):\n", | |
| " order = np.argsort(pred[i])[::-1]\n", | |
| " ranked = true[i][order]\n", | |
| " \n", | |
| " running_nrelevant = np.cumsum(ranked)\n", | |
| " running_rank = np.arange(1, ranked.shape[0] + 1)\n", | |
| " running_precision = running_nrelevant[ranked > 0] / running_rank[ranked > 0]\n", | |
| "\n", | |
| " average_precisions.append(running_precision.mean())\n", | |
| "\n", | |
| " return np.array(average_precisions)\n", | |
| "\n", | |
| "p = np.array([\n", | |
| " [3, 2, 1],\n", | |
| " [3, 2, 1],\n", | |
| " [3, 2, 1]\n", | |
| "])\n", | |
| "\n", | |
| "t = np.array([\n", | |
| " [0, 0, 1],\n", | |
| " [0, 1, 1],\n", | |
| " [1, 1, 0]\n", | |
| "])\n", | |
| "\n", | |
| "average_precision(t, p)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "coastal-fluid", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.6388888888888888" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "def map_(true, pred):\n", | |
| " return average_precision(true, pred).mean()\n", | |
| "\n", | |
| "map_(t, p)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "global-requirement", | |
| "metadata": {}, | |
| "source": [ | |
| "### NDCG\n", | |
| "\n", | |
| "Applies a \"discount\" to scores based on sorted list index.\n", | |
| "There are multiple versions, and different implementations may yield different results.\n", | |
| "\n", | |
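| "$ DCG@N = \\sum_{i=1}^{N} \\frac{y_i}{\\log_2(1+i)} $\n", | |
| "\n", | |
| "where $y_i$ is the relevance label of the item ranked at position $i$.\n", | |
| "\n", | |
| "The one people actually use:\n", | |
| "\n", | |
| "$ DCG@N = \\sum_{i=1}^{N} \\frac{2^{y_i} - 1}{\\log_2(1+i)} $\n", | |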
| "\n", | |
| "> The version with $2^{y_i}$ is often preferred as it has a smoother gain curve for the different grades of relevance - Guido Zuccon\n", | |
| "\n", | |
| "\n", | |
| "To find the normalised DCG (NDCG) you divide the DCG by the ideal DCG.\n", | |
| "Ideal DCG is obtained by a perfect predicted ordering wrt your relevance labels.\n", | |
| "\n", | |
| "Great overview and analysis of NDCG in:\n", | |
| "\n", | |
| "Clémençon S., Lugosi G., Vayatis N. (2005) Ranking and Scoring Using Empirical Risk Minimization. In: Auer P., Meir R. (eds) Learning Theory. COLT 2005. Lecture Notes in Computer Science, vol 3559. Springer, Berlin, Heidelberg. https://doi.org/10.1007/11503415_1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "hollywood-flood", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.5 , 1.13092975, 1.63092975])" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "def discounted_cumulative_gain(true, pred, at=None, original=False):\n", | |
| " import numpy as np\n", | |
| "\n", | |
| " dcgs = []\n", | |
| "\n", | |
| " for i in range(true.shape[0]):\n", | |
| " order = np.argsort(pred[i])[::-1]\n", | |
| " ranked = true[i][order]\n", | |
| "\n", | |
| " if at is not None:\n", | |
| " # Only evaluating in top k\n", | |
| " ranked = ranked[:at]\n", | |
| "\n", | |
| " if original:\n", | |
| " # Ranked score divided by rank of prediction\n", | |
| " # Note log2 means values with higher ranks (i.e. less relevant)\n", | |
| " # has weaker effect on score.\n", | |
| " ratios = ranked / np.log2(np.arange(1, ranked.shape[0] + 1) + 1)\n", | |
| " else:\n", | |
| " # Alternative ratio to make emphasis on most relevant stronger.\n", | |
| " ratios = ((2 ** ranked) - 1) / np.log2(np.arange(1, ranked.shape[0] + 1) + 1)\n", | |
| " dcgs.append(ratios.sum())\n", | |
| "\n", | |
| " return np.array(dcgs)\n", | |
| "\n", | |
| "p = np.array([\n", | |
| " [3, 2, 1],\n", | |
| " [3, 2, 1],\n", | |
| " [3, 2, 1]\n", | |
| "])\n", | |
| "\n", | |
| "t = np.array([\n", | |
| " [0, 0, 1],\n", | |
| " [0, 1, 1],\n", | |
| " [1, 1, 0]\n", | |
| "])\n", | |
| "\n", | |
| "discounted_cumulative_gain(t, p)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "sustained-player", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([1. , 1.63092975, 1.63092975])" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "discounted_cumulative_gain(t, t)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "cardiovascular-invite", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.5 , 0.6934264, 1. ])" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "discounted_cumulative_gain(t, p) / discounted_cumulative_gain(t, t)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "flexible-laugh", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.7311421345390903" | |
| ] | |
| }, | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "def normalised_dcg(true, pred, at=None, original=False):\n", | |
| " # Actual dcg\n", | |
| " dcgs = discounted_cumulative_gain(true, pred, at, original=original)\n", | |
| "\n", | |
| " # Maximum possible dcg\n", | |
| " ideal_dcgs = discounted_cumulative_gain(true, true, at, original=original)\n", | |
| "\n", | |
| " # Handle edge case where ideal dcg is zero\n", | |
| " # This could happen if you have no relevant items in your list.\n", | |
| " dcgs[ideal_dcgs == 0.] = 0\n", | |
| " ideal_dcgs[ideal_dcgs == 0.] = 1\n", | |
| "\n", | |
| " return (dcgs / ideal_dcgs).mean()\n", | |
| "\n", | |
| "normalised_dcg(t, p)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "id": "victorian-wallet", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "MRR 0.625\n", | |
| "MAP [0.8055555555555555, 0.325]\n", | |
| "MAP 0.5652777777777778\n", | |
| "NDCG [0.90602544 0.50126584]\n", | |
| "NDCG 0.7036456354382847\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "p = np.array([\n", | |
| " np.arange(5)[::-1],\n", | |
| " np.arange(5)[::-1]\n", | |
| "])\n", | |
| "\n", | |
| "t = np.array([\n", | |
| " [1, 0, 1, 1, 0],\n", | |
| " [0, 0, 0, 1, 1]\n", | |
| "])\n", | |
| "\n", | |
| "print(\"MRR\", mrr(t, p))\n", | |
| "\n", | |
| "print(\"MAP\", [map_(ti.reshape(1, -1), pi.reshape(1, -1)) for ti, pi in zip(t, p)])\n", | |
| "print(\"MAP\", map_(t, p))\n", | |
| "\n", | |
| "print(\"NDCG\", discounted_cumulative_gain(t, p) / discounted_cumulative_gain(t, t))\n", | |
| "print(\"NDCG\", normalised_dcg(t, p))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "id": "according-banks", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.94881075, 0.59089746, 0.66534971, 0.88545988])" | |
| ] | |
| }, | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# A multi-level example\n", | |
| "\n", | |
| "true = np.array([\n", | |
| " [3, 2, 3, 0, 1, 2],\n", | |
| " [0, 1, 2, 2, 3, 3],\n", | |
| " [0, 3, 0, 3, 0, 3],\n", | |
| " [3, 0, 3, 0, 3, 0],\n", | |
| "])\n", | |
| "\n", | |
| "pred = np.array([\n", | |
| " [5, 4, 3, 2, 1, 0],\n", | |
| " [5, 4, 3, 2, 1, 0],\n", | |
| " [5, 4, 3, 2, 1, 0],\n", | |
| " [5, 4, 3, 2, 1, 0],\n", | |
| "])\n", | |
| "\n", | |
| "discounted_cumulative_gain(true, pred) / discounted_cumulative_gain(true, true)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "inside-harmony", | |
| "metadata": {}, | |
| "source": [ | |
| "Scikit-learn has an implementation of NDCG but it only uses the original score function (rather than the in-practice standard $2^{y_i} - 1$)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "id": "searching-excellence", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.7931987348917845" | |
| ] | |
| }, | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "normalised_dcg(true, pred, original=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "id": "floral-nitrogen", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.7931987348917845" | |
| ] | |
| }, | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "from sklearn.metrics import ndcg_score\n", | |
| "ndcg_score(true, pred)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "adolescent-trader", | |
| "metadata": {}, | |
| "source": [ | |
| "Note that using priorities `1, 2, ...` instead of `0, 1, ...` will give different results.\n", | |
| "\n", | |
| "\"Correct\" usage is to start at 0 for irrelevant." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "id": "abroad-fluid", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.7726294517134219" | |
| ] | |
| }, | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "normalised_dcg(true, pred)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "id": "northern-george", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.7955308580375091" | |
| ] | |
| }, | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "normalised_dcg(true + 1, pred)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "inner-confidence", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "combined-equality", | |
| "metadata": {}, | |
| "source": [ | |
| "NDCG can be a bit difficult to interpret, and varies with length of list and labels used.\n", | |
| "The worst possible order will still have NDCG > 0 (as long as the list contains at least one relevant item)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "id": "loving-cliff", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.59089746, 0.59089746, 0.55080959, 0.55080959])" | |
| ] | |
| }, | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "discounted_cumulative_gain(true, (true.max() - true)) / discounted_cumulative_gain(true, true)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "specialized-leave", | |
| "metadata": {}, | |
| "source": [ | |
| "And the average NDCG given by randomly ordered lists can be quite high." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "id": "optimum-jackson", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from numpy.random import default_rng" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "id": "detailed-albert", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "rng = default_rng()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "id": "parental-trinity", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[2, 3, 0, 4, 5, 1],\n", | |
| " [4, 5, 1, 3, 2, 0],\n", | |
| " [2, 5, 1, 0, 4, 3],\n", | |
| " [0, 3, 5, 4, 2, 1]])" | |
| ] | |
| }, | |
| "execution_count": 25, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pred2 = np.array([\n", | |
| " rng.choice(np.arange(6), 6, replace=False),\n", | |
| " rng.choice(np.arange(6), 6, replace=False),\n", | |
| " rng.choice(np.arange(6), 6, replace=False),\n", | |
| " rng.choice(np.arange(6), 6, replace=False)\n", | |
| "])\n", | |
| "pred2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "id": "unnecessary-beauty", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.62819453, 0.62819453, 0.87107854, 0.83854653])" | |
| ] | |
| }, | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "discounted_cumulative_gain(true, pred2) / discounted_cumulative_gain(true, true)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "geological-burke", | |
| "metadata": {}, | |
| "source": [ | |
| "Rather than thinking about \"how much better than random we are doing\", it might be better to think of \"how far away from the optimal solution are we?\".\n", | |
| "\n", | |
| "\n", | |
| "#### Expected reciprocal rank (ERR)\n", | |
| "\n", | |
| "People don't seem to use this so much.\n", | |
| "\n", | |
| "$$\n", | |
| "ERR = \\sum_{r=1}^{N} \\frac{1}{r}R_r\\prod_{i=1}^{r-1}(1-R_i)\n", | |
| "$$\n", | |
| "\n", | |
| "if $y_m$ is the maximum label value,\n", | |
| "\n", | |
| "$$\n", | |
| "R_i = \\frac{2^{y_i}-1}{2^{y_m}}\n", | |
| "$$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "incomplete-clinton", | |
| "metadata": {}, | |
| "source": [ | |
| "## Pointwise ranking - MCrank\n", | |
| "\n", | |
| "Learning to rank as \"multiple ordinal classification\" with boosted trees.\n", | |
| "\n", | |
| "http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.992.5789&rep=rep1&type=pdf\n", | |
| "\n", | |
| "- Show that there's a relationship between DCG and multiple classification.\n", | |
| "- Two multiple classification approaches (and a regression based solution) were quite successful compared to LambdaRank.\n", | |
| "\n", | |
| "Expected rank:\n", | |
| "\n", | |
| "1. Train a multi-class predictor (one class for each rank level: 0, 1, 2, ...).\n", | |
| "2. Weight each class probability by some monotonically increasing function $f$ of its rank level $k$, and sum over the classes.\n", | |
| "\n", | |
| "$$\n", | |
| "S_i = \\sum_{k=0}^{K - 1}\\bar{P}(y_i = k)f(k)\n", | |
| "$$\n", | |
| "\n", | |
| "3. Sort by $S_i$\n", | |
| "\n", | |
| "\n", | |
| "I'll define some example data that I'll use for all of the examples." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 27, | |
| "id": "massive-astronomy", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[ 0.82418808, 0.479966 ],\n", | |
| " [ 1.17346801, 0.90904807],\n", | |
| " [-0.57172145, -0.10949727],\n", | |
| " [ 0.01902826, -0.94376106],\n", | |
| " [ 0.64057315, -0.78644317]])" | |
| ] | |
| }, | |
| "execution_count": 27, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "np.random.seed(666)\n", | |
| "\n", | |
| "X = np.random.normal(size=20000).reshape(-1, 2)\n", | |
| "X[:5]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 28, | |
| "id": "robust-pollution", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0., 1., 2., 3.])" | |
| ] | |
| }, | |
| "execution_count": 28, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "def quantise(arr, maxpriority=3):\n", | |
| " return (maxpriority * (arr - arr.min()) / (arr.max() - arr.min())).round()\n", | |
| "\n", | |
| "# This just makes sure all have all relevance levels\n", | |
| "y = np.zeros(10000)\n", | |
| "i = 0\n", | |
| "for j in range(100, 10001, 100):\n", | |
| " y[i:j] = quantise((5 * X[i:j, 0] + 1.5 * X[i:j, 1]))\n", | |
| " i = j\n", | |
| "\n", | |
| "np.unique(y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 29, | |
| "id": "invisible-complement", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0., 0., 1., 0.],\n", | |
| " [0., 0., 1., 0.],\n", | |
| " [0., 1., 0., 0.],\n", | |
| " [0., 1., 0., 0.]])" | |
| ] | |
| }, | |
| "execution_count": 29, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "from sklearn.preprocessing import OneHotEncoder\n", | |
| "from sklearn.ensemble import RandomForestClassifier\n", | |
| "\n", | |
| "\n", | |
| "y_ohe = OneHotEncoder(drop=None, sparse=False).fit_transform(y.reshape(-1, 1))\n", | |
| "y_ohe[:4,]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 30, | |
| "id": "operating-cattle", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0., 0., 1., 0.],\n", | |
| " [0., 0., 1., 0.],\n", | |
| " [0., 1., 0., 0.],\n", | |
| " [0., 1., 0., 0.],\n", | |
| " [0., 0., 1., 0.]])" | |
| ] | |
| }, | |
| "execution_count": 30, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model = RandomForestClassifier(max_depth=4)\n", | |
| "model.fit(X, y_ohe)\n", | |
| "model.predict(X)[:5]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 31, | |
| "id": "seven-official", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([1.97749313, 2.0729215 , 1.06639982, ..., 1.61335939, 1.00231679,\n", | |
| " 1.8529369 ])" | |
| ] | |
| }, | |
| "execution_count": 31, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "proba = model.predict_proba(X)\n", | |
| "preds = np.array([p[:, 1] for p in proba]).T\n", | |
| "preds = np.sum(preds * np.arange(4), axis=1)\n", | |
| "preds" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 32, | |
| "id": "cheap-pierre", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[(3.0, 2.677313467063255),\n", | |
| " (3.0, 2.664500193638078),\n", | |
| " (3.0, 2.662675469878174),\n", | |
| " (3.0, 2.6614690082257124),\n", | |
| " (3.0, 2.656291430113934),\n", | |
| " (3.0, 2.6503448633931153),\n", | |
| " (3.0, 2.6503136455515315),\n", | |
| " (3.0, 2.648891314606223),\n", | |
| " (3.0, 2.648891314606223),\n", | |
| " (3.0, 2.648891314606223)]" | |
| ] | |
| }, | |
| "execution_count": 32, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "sorted(zip(y, preds), key=lambda t: t[1], reverse=True)[:10]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 33, | |
| "id": "damaged-painting", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[(0.0, 0.3397197218572301),\n", | |
| " (0.0, 0.335251017634349),\n", | |
| " (0.0, 0.33390507925899887),\n", | |
| " (0.0, 0.33352850084104024),\n", | |
| " (0.0, 0.33042760238756513),\n", | |
| " (0.0, 0.3258391060681249),\n", | |
| " (0.0, 0.32583766825874333),\n", | |
| " (0.0, 0.3244991526954715),\n", | |
| " (0.0, 0.3191379092655178),\n", | |
| " (0.0, 0.3114470766832209)]" | |
| ] | |
| }, | |
| "execution_count": 33, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "sorted(zip(y, preds), key=lambda t: t[1], reverse=True)[-10:]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 34, | |
| "id": "heated-decision", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.9940772372662219" | |
| ] | |
| }, | |
| "execution_count": 34, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "normalised_dcg(y.reshape(1, -1), preds.reshape(1, -1))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "boxed-paradise", | |
| "metadata": {}, | |
| "source": [ | |
| "Ordinal classification:\n", | |
| "\n", | |
| "1. Train a classifier to predict cumulative probabilities i.e. $\\bar{P}(y_i \\le k)$\n", | |
| "2. Then find $\\bar{P}(y_i=k) = \\bar{P}(y_i \\le k) - \\bar{P}(y_i \\le k-1)$\n", | |
| "3. Calculate scores as before." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 35, | |
| "id": "naval-builder", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0., 0., 1., 1.],\n", | |
| " [0., 0., 1., 1.],\n", | |
| " [0., 1., 1., 1.],\n", | |
| " ...,\n", | |
| " [0., 0., 1., 1.],\n", | |
| " [0., 1., 1., 1.],\n", | |
| " [0., 0., 1., 1.]])" | |
| ] | |
| }, | |
| "execution_count": 35, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "y_ohe2 = np.zeros((10000, 4))\n", | |
| "y_ohe2[:, 0] = (y <= 0).astype(int)\n", | |
| "y_ohe2[:, 1] = (y <= 1).astype(int)\n", | |
| "y_ohe2[:, 2] = (y <= 2).astype(int)\n", | |
| "y_ohe2[:, 3] = (y <= 3).astype(int)\n", | |
| "y_ohe2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 36, | |
| "id": "effective-capacity", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0., 0., 1., 1.],\n", | |
| " [0., 0., 1., 1.],\n", | |
| " [0., 1., 1., 1.],\n", | |
| " [0., 1., 1., 1.],\n", | |
| " [0., 0., 1., 1.]])" | |
| ] | |
| }, | |
| "execution_count": 36, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model = RandomForestClassifier(max_depth=4)\n", | |
| "model.fit(X, y_ohe2)\n", | |
| "model.predict(X)[:5]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 37, | |
| "id": "about-vision", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0.00156794, 0.08226345, 0.93272601, 1. ],\n", | |
| " [0.0016215 , 0.05687561, 0.84154191, 1. ],\n", | |
| " [0.04926611, 0.87319555, 0.99565484, 1. ],\n", | |
| " ...,\n", | |
| " [0.01627381, 0.4333639 , 0.98515771, 1. ],\n", | |
| " [0.11859363, 0.90130716, 0.99911182, 1. ],\n", | |
| " [0.01299227, 0.17832554, 0.96272132, 1. ]])" | |
| ] | |
| }, | |
| "execution_count": 37, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "proba = model.predict_proba(X)\n", | |
| "proba = np.array([p[:, -1] for p in proba]).T\n", | |
| "proba" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 38, | |
| "id": "bound-replica", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([1.9834426 , 2.09996098, 1.0818835 , 1.19870506, 1.91509632])" | |
| ] | |
| }, | |
| "execution_count": 38, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "preds = np.zeros_like(proba)\n", | |
| "preds[:, 3] = proba[:, 3] - proba[:, 2]\n", | |
| "preds[:, 2] = proba[:, 2] - proba[:, 1]\n", | |
| "preds[:, 1] = proba[:, 1] - proba[:, 0]\n", | |
| "preds[:, 0] = proba[:, 0]\n", | |
| "\n", | |
| "preds = np.sum(preds * np.arange(4), axis=1)\n", | |
| "preds[:5]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 39, | |
| "id": "blond-headquarters", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[(3.0, 2.659835580785948),\n", | |
| " (3.0, 2.618482471782334),\n", | |
| " (3.0, 2.6177394743764864),\n", | |
| " (3.0, 2.6154220385465967),\n", | |
| " (3.0, 2.6091923047850396),\n", | |
| " (3.0, 2.6071536935013717),\n", | |
| " (3.0, 2.605864341426174),\n", | |
| " (3.0, 2.6048851218042937),\n", | |
| " (3.0, 2.603226828313331),\n", | |
| " (3.0, 2.603226828313331)]" | |
| ] | |
| }, | |
| "execution_count": 39, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "sorted(zip(y, preds), key=lambda t: t[1], reverse=True)[:10]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 40, | |
| "id": "fabulous-expert", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[(0.0, 0.38541689664503054),\n", | |
| " (0.0, 0.3794059606745688),\n", | |
| " (0.0, 0.3778614518262179),\n", | |
| " (0.0, 0.3766454173898445),\n", | |
| " (0.0, 0.37562909145854884),\n", | |
| " (0.0, 0.37473664556868047),\n", | |
| " (0.0, 0.3728656881162348),\n", | |
| " (0.0, 0.3674908811208747),\n", | |
| " (0.0, 0.3643410653153032),\n", | |
| " (0.0, 0.3609506085625409)]" | |
| ] | |
| }, | |
| "execution_count": 40, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "sorted(zip(y, preds), key=lambda t: t[1], reverse=True)[-10:]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 41, | |
| "id": "swiss-photograph", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.9937158387167305" | |
| ] | |
| }, | |
| "execution_count": 41, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "normalised_dcg(y.reshape(1, -1), preds.reshape(1, -1))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "compatible-robert", | |
| "metadata": {}, | |
| "source": [ | |
| "## Pairwise approaches\n", | |
| "\n", | |
| "Essentially train a binary classifier on differences of pairs of items.\n", | |
| "\n", | |
| "Earlier pairwise ranking SVMs work with a transformation of pairs.\n", | |
| "\n", | |
| "$$\n", | |
| "X_{pairij} = X_i - X_j \\\\\n", | |
| "y_{pairij} = \\left\\{\\begin{array}{lr}\n", | |
| " -1, & \\text{if } y_i \\lt y_j \\\\\n", | |
| " 0, & \\text{if } y_i = y_j \\\\\n", | |
| " 1, & \\text{if } y_i \\gt y_j\n", | |
| " \\end{array}\\right\\}\n", | |
| "$$\n", | |
| "\n", | |
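| "Since each pair appears twice (once as $(i, j)$ and once as $(j, i)$, with opposite signs) you only really need one combination. The code below encodes the labels as 1, 0 and 0.5 (instead of 1, -1 and 0) so that they can be used directly as targets for the log loss.\n", | |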
| "- Good overview of earlier SVM approaches: https://doi.org/10.1007/978-3-642-01307-2_39" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 42, | |
| "id": "efficient-language", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Our larger dataset is a bit too big to do all combinations, so take a subset.\n", | |
| "X_sub = X[:1000]\n", | |
| "y_sub = y[:1000]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 43, | |
| "id": "explicit-ireland", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from itertools import combinations\n", | |
| "\n", | |
| "X_pairs = []\n", | |
| "y_pairs = []\n", | |
| "\n", | |
| "for i, j in combinations(range(X_sub.shape[0]), 2):\n", | |
| " if y_sub[i] > y_sub[j]:\n", | |
| " y_pairij = 1\n", | |
| " elif y_sub[i] < y_sub[j]:\n", | |
| " y_pairij = 0\n", | |
| " else:\n", | |
| " y_pairij = 0.5\n", | |
| "\n", | |
| " X_pairij = X_sub[i] - X_sub[j]\n", | |
| "\n", | |
| " X_pairs.append(X_pairij)\n", | |
| " y_pairs.append(y_pairij)\n", | |
| "\n", | |
| "X_pairs = np.array(X_pairs)\n", | |
| "y_pairs = np.array(y_pairs)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 44, | |
| "id": "nearby-cleaner", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[-0.34927993, -0.42908207],\n", | |
| " [ 1.39590954, 0.58946327],\n", | |
| " [ 0.80515982, 1.42372707],\n", | |
| " ...,\n", | |
| " [-0.8909495 , -0.06959686],\n", | |
| " [ 1.33606144, -0.54053795],\n", | |
| " [ 2.22701094, -0.47094109]])" | |
| ] | |
| }, | |
| "execution_count": 44, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_pairs" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 45, | |
| "id": "speaking-opposition", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.5, 1. , 1. , ..., 0.5, 1. , 1. ])" | |
| ] | |
| }, | |
| "execution_count": 45, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "y_pairs" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 46, | |
| "id": "answering-roulette", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def log_loss(true, pred):\n", | |
| " return ((true * np.log(pred)) + ((1 - true) * np.log(1 - pred))).sum() / -true.shape[0]\n", | |
| "\n", | |
| "def sigmoid(pred):\n", | |
| " return 1 / (1 + np.exp(-pred))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "developing-rebound", | |
| "metadata": {}, | |
| "source": [ | |
| "Train a quick gradient descent model." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 47, | |
| "id": "frank-marble", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "loss: 0.4487852673800323\n", | |
| "weights: [1.19453351 0.38415072]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "learning_rate = 0.1\n", | |
| "nepoch = 100\n", | |
| "\n", | |
| "w = np.random.normal(scale=0.01, size=2).reshape(1, 2)\n", | |
| "for i in range(nepoch):\n", | |
| " preds = sigmoid(w.dot(X_pairs.T))[0]\n", | |
| " gradient = np.sum((y_pairs - preds).reshape(-1, 1) * X_pairs, axis=0).reshape(1, 2) / X_pairs.shape[0]\n", | |
| " w += learning_rate * gradient\n", | |
| "\n", | |
| " preds = sigmoid(w.dot(X_pairs.T))[0]\n", | |
| " loss = log_loss(y_pairs, preds)\n", | |
| "\n", | |
| "print(\"loss:\", loss)\n", | |
| "print(\"weights:\", w[0])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "manufactured-seller", | |
| "metadata": {}, | |
| "source": [ | |
| "We can now use these weights to perform pairwise order comparisons using your favourite sorting algorithm." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 48, | |
| "id": "distant-updating", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class Obs(object):\n", | |
| "\n", | |
| " def __init__(self, arr, weights):\n", | |
| " self.arr = arr\n", | |
| " self.weights = weights\n", | |
| " return\n", | |
| "\n", | |
| " def __repr__(self):\n", | |
| " arr = [round(x, ndigits=2) for x in self.arr]\n", | |
| " weights = [round(x, ndigits=2) for x in self.weights]\n", | |
| " return f\"Obs(arr={arr}, weights={weights})\"\n", | |
| "\n", | |
| " def predict(self, other):\n", | |
| " return sigmoid(self.weights.reshape(1, 2).dot(self.arr.reshape(2, 1) - other.arr.reshape(2, 1)))[0, 0]\n", | |
| "\n", | |
| " def __gt__(self, other):\n", | |
| " return self.predict(other) > 0.5\n", | |
| "\n", | |
| " def __lt__(self, other):\n", | |
| " return self.predict(other) < 0.5\n", | |
| "\n", | |
| " def __eq__(self, other):\n", | |
| " return self.arr == other.arr" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "continent-communications", | |
| "metadata": {}, | |
| "source": [ | |
| "Python's sorting functions compare items with `<`, i.e. the `__lt__` method, which calls our predict function." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 49, | |
| "id": "raising-fighter", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[Obs(arr=[-3.8, -1.3], weights=[1.19, 0.38]),\n", | |
| " Obs(arr=[-4.11, -0.11], weights=[1.19, 0.38]),\n", | |
| " Obs(arr=[-3.36, -1.32], weights=[1.19, 0.38]),\n", | |
| " Obs(arr=[-3.1, -2.05], weights=[1.19, 0.38]),\n", | |
| " Obs(arr=[-3.61, -0.45], weights=[1.19, 0.38]),\n", | |
| " Obs(arr=[-3.63, -0.34], weights=[1.19, 0.38]),\n", | |
| " Obs(arr=[-3.29, -1.28], weights=[1.19, 0.38]),\n", | |
| " Obs(arr=[-3.9, 0.81], weights=[1.19, 0.38]),\n", | |
| " Obs(arr=[-2.61, -2.25], weights=[1.19, 0.38]),\n", | |
| " Obs(arr=[-3.15, -0.48], weights=[1.19, 0.38])]" | |
| ] | |
| }, | |
| "execution_count": 49, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "observations = [Obs(xi, w[0]) for xi in X]\n", | |
| "sorted(observations)[:10]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "modern-console", | |
| "metadata": {}, | |
| "source": [ | |
| "I'll combine it with the ranks to check if it seems reasonable." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 50, | |
| "id": "presidential-budapest", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[(3.0, Obs(arr=[4.03, 0.57], weights=[1.19, 0.38])),\n", | |
| " (3.0, Obs(arr=[3.32, 0.27], weights=[1.19, 0.38])),\n", | |
| " (3.0, Obs(arr=[3.06, 0.92], weights=[1.19, 0.38])),\n", | |
| " (3.0, Obs(arr=[3.31, 0.09], weights=[1.19, 0.38])),\n", | |
| " (3.0, Obs(arr=[2.68, 2.0], weights=[1.19, 0.38])),\n", | |
| " (3.0, Obs(arr=[2.77, 1.53], weights=[1.19, 0.38])),\n", | |
| " (3.0, Obs(arr=[3.14, 0.29], weights=[1.19, 0.38])),\n", | |
| " (3.0, Obs(arr=[2.93, 0.67], weights=[1.19, 0.38])),\n", | |
| " (3.0, Obs(arr=[2.87, 0.82], weights=[1.19, 0.38])),\n", | |
| " (3.0, Obs(arr=[2.67, 1.31], weights=[1.19, 0.38]))]" | |
| ] | |
| }, | |
| "execution_count": 50, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "observations = sorted(zip(y, observations), key=lambda t: t[1], reverse=True)\n", | |
| "observations[:10]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "hundred-sacrifice", | |
| "metadata": {}, | |
| "source": [ | |
| "Remember that sorting like this will usually need $O(n \\log_2 n)$ comparisons, and each comparison evaluates the model. A complex function might take a long time to evaluate.\n", | |
| "\n", | |
| "For a strictly linear function, you can use the model directly on the features (rather than the space of differences)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 51, | |
| "id": "imposed-virtue", | |
| "metadata": { | |
| "scrolled": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[(3.0, 5.026986168572513),\n", | |
| " (3.0, 4.064973208581116),\n", | |
| " (3.0, 4.004872670382963),\n", | |
| " (3.0, 3.9923592281037013),\n", | |
| " (3.0, 3.9668530791630476),\n", | |
| " (3.0, 3.899361741345462),\n", | |
| " (3.0, 3.864273210501915),\n", | |
| " (3.0, 3.7553321255701824),\n", | |
| " (3.0, 3.7361629846308007),\n", | |
| " (3.0, 3.68819389908362)]" | |
| ] | |
| }, | |
| "execution_count": 51, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "sorted(zip(y, w.dot(X.T)[0]), key=lambda t: t[1], reverse=True)[:10]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "imperial-hammer", | |
| "metadata": {}, | |
| "source": [ | |
| "### RankNet\n", | |
| "\n", | |
| "https://www.microsoft.com/en-us/research/wp-content/uploads/2005/08/icml_ranking.pdf\n", | |
| "\n", | |
| "Basically it's a siamese network with logistic activation and cross entropy loss.\n", | |
| "More or less as above, but moving where we take the difference so that we speed up evaluation time.\n", | |
| "RankBoost did the same thing but with a fancy schmancy loss function.\n", | |
| "\n", | |
| "They also provide a more thorough analysis of why mapping to a probability of $x_i > x_j$ is useful and consistent." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 52, | |
| "id": "precise-holder", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from itertools import combinations\n", | |
| "\n", | |
| "X_pairs = []\n", | |
| "y_pairs = []\n", | |
| "\n", | |
| "for i, j in combinations(range(X_sub.shape[0]), 2):\n", | |
| " if y_sub[i] > y_sub[j]:\n", | |
| " y_pairij = 1\n", | |
| " else:\n", | |
| " continue\n", | |
| " #elif y[i] < y[j]:\n", | |
| " # y_pairij = 0\n", | |
| " #else:\n", | |
| " # y_pairij = 0.5\n", | |
| "\n", | |
| " X_pairs.append((X_sub[i], X_sub[j]))\n", | |
| " y_pairs.append(y_pairij)\n", | |
| "\n", | |
| "X_pairs = np.array(X_pairs)\n", | |
| "y_pairs = np.array(y_pairs)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 53, | |
| "id": "educational-organization", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0.82418808, 0.479966 ],\n", | |
| " [0.82418808, 0.479966 ],\n", | |
| " [0.82418808, 0.479966 ],\n", | |
| " [0.82418808, 0.479966 ],\n", | |
| " [0.82418808, 0.479966 ],\n", | |
| " [0.82418808, 0.479966 ],\n", | |
| " [0.82418808, 0.479966 ],\n", | |
| " [0.82418808, 0.479966 ],\n", | |
| " [0.82418808, 0.479966 ],\n", | |
| " [0.82418808, 0.479966 ]])" | |
| ] | |
| }, | |
| "execution_count": 53, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_pairs[:10, 0, :]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 54, | |
| "id": "still-encyclopedia", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[-0.57172145, -0.10949727],\n", | |
| " [ 0.01902826, -0.94376106],\n", | |
| " [-0.29873262, -0.46058737],\n", | |
| " [-1.08879299, -0.57577075],\n", | |
| " [-1.68290077, 0.22918525],\n", | |
| " [-1.75662522, 0.84463262],\n", | |
| " [-0.6564723 , -0.2015057 ],\n", | |
| " [-0.70061583, 0.68713795],\n", | |
| " [-0.02607576, -0.82975832],\n", | |
| " [-0.61130127, -0.8217515 ]])" | |
| ] | |
| }, | |
| "execution_count": 54, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_pairs[:10, 1, :]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 55, | |
| "id": "respiratory-replica", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import tensorflow as tf\n", | |
| "from tensorflow.keras import Sequential\n", | |
| "from tensorflow.keras.layers import Dense\n", | |
| "\n", | |
| "from tensorflow.keras.activations import sigmoid as tfsigmoid\n", | |
| "from tensorflow.keras.losses import BinaryCrossentropy\n", | |
| "from tensorflow.keras.optimizers import Adam" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 56, | |
| "id": "minus-compatibility", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "tf.Tensor(0.11192029, shape=(), dtype=float32)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "optimizer = Adam(learning_rate=0.1)\n", | |
| "loss_fn = BinaryCrossentropy(from_logits=True)\n", | |
| "\n", | |
| "model = Dense(1, activation=\"linear\", use_bias=False)\n", | |
| "\n", | |
| "n_epoch = 100\n", | |
| "for _ in range(n_epoch):\n", | |
| " with tf.GradientTape() as tape:\n", | |
| " s1 = model(X_pairs[:, 0, :])\n", | |
| " s2 = model(X_pairs[:, 1, :])\n", | |
| " logits = tfsigmoid(s1 - s2)\n", | |
| " loss_value = loss_fn(y_pairs.reshape(-1, 1), logits)\n", | |
| "\n", | |
| " grads = tape.gradient(loss_value, model.trainable_variables)\n", | |
| " optimizer.apply_gradients(zip(grads, model.trainable_variables))\n", | |
| "\n", | |
| "print(loss_value)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 57, | |
| "id": "religious-blond", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[array([[3.30593 ],\n", | |
| " [1.098215]], dtype=float32)]" | |
| ] | |
| }, | |
| "execution_count": 57, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model.get_weights()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 58, | |
| "id": "random-therapist", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[(3.0, 13.932344),\n", | |
| " (3.0, 11.259557),\n", | |
| " (3.0, 11.116082),\n", | |
| " (3.0, 11.052165),\n", | |
| " (3.0, 11.048449),\n", | |
| " (3.0, 10.845408),\n", | |
| " (3.0, 10.704596),\n", | |
| " (3.0, 10.416609),\n", | |
| " (3.0, 10.368637),\n", | |
| " (3.0, 10.253225)]" | |
| ] | |
| }, | |
| "execution_count": 58, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "y_pred = model(X).numpy()\n", | |
| "\n", | |
| "sorted(zip(y, y_pred[:, 0]), key=lambda t: t[1], reverse=True)[:10]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "executed-disney", | |
| "metadata": {}, | |
| "source": [ | |
| "There is a newer pairwise ranking method that's popular in the recommender system world, which uses the WARP loss:\n", | |
| "\n", | |
| "- http://www.gabormelli.com/RKB/Weighted_Approximately_Ranked_Pairwise_(WARP)_Ranking_Loss\n", | |
| "- https://www.hongliangjie.com/2012/08/24/weighted-approximately-ranked-pairwise-loss-warp/\n", | |
| "\n", | |
| "I haven't looked at it yet, but apparently it's quite good.\n", | |
| "\n", | |
| "Triplet losses might also be interesting for people interested in metric or embedding learning.\n", | |
| "\n", | |
| "- Very successful for learning semantic representations of data.\n", | |
| "- E.g. dense embeddings of very sparse, multi-class problems.\n", | |
| "- https://arxiv.org/abs/1412.6622\n", | |
| "\n", | |
| "\n", | |
| "Define an anchor term $a$ which is your item of interest, then sample a positive $p$ and negative $n$ example from the data.\n", | |
| "$p$ should share a class annotation with $a$ and $n$ should not share any annotations.\n", | |
| "\n", | |
| "The idea is to learn a distance function $f$ that maximises the distance between $a$ and $n$, and minimises distance between $a$ and $p$.\n", | |
| "\n", | |
| "$$\n", | |
| "Loss = \\max(f(a, p) - f(a, n) + \\text{margin}, 0)\n", | |
| "$$\n", | |
| "\n", | |
| "Similar in spirit to SVM classification (pushing two opposing classes apart).\n", | |
| "\n", | |
| "\n", | |
| "## Listwise ranking\n", | |
| "\n", | |
| "In many applications you're really only interested in the top few hits (e.g. search engines).\n", | |
| "Pairwise approaches give equal importance for all pairs, but we want to emphasise the relevant items during training.\n", | |
| "\n", | |
| "The next step was to optimise list-wise ranking directly, where the loss function is NDCG (or similar) of the whole list.\n", | |
| "This means you require multiple lists, e.g. many query terms, each with a ranked list of results.\n", | |
| "\n", | |
| "So the idea is to minimise some function:\n", | |
| "\n", | |
| "$$\n", | |
| "\\frac{1}{m} \\sum_{i=1}^{m}loss(sort(f(x_1^i), ...), y^i)\n", | |
| "$$\n", | |
| "\n", | |
| "\n", | |
| "### LambdaRank\n", | |
| "\n", | |
| "Extended RankNet by optimising directly for NDCG (or any other list-wise metric).\n", | |
| "\n", | |
| "- https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/lambdarank.pdf\n", | |
| "- https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/SoftRankWsdm08Submitted.pdf\n", | |
| "- https://www.cs.cmu.edu/~pinard/Papers/sigirfp092-donmez.pdf\n", | |
| "\n", | |
| "\n", | |
| "- list-ranking statistics are highly discontinuous/non-smooth/non-differentiable.\n", | |
| "- Key observation here is that you don't need the results of the loss function, just the gradient.\n", | |
| "- So you can add your own slope to the gradient for update and use that instead.\n", | |
| "\n", | |
| "\n", | |
| "From RankNet given our loss function $C$ and weights $w$, predictions $s_i = f(x_i)$ and $S_{ij} \\in \\{-1, 0, 1\\}$.\n", | |
| "\n", | |
| "You can factor out the gradient given by RankNet to be:\n", | |
| "\n", | |
| "$$\n", | |
| "\\frac{\\partial C}{\\partial w_k} = \\lambda_{ij} \\left(\\frac{\\partial s_i}{\\partial w_k} - \\frac{\\partial s_j}{\\partial w_k}\\right)\n", | |
| "$$\n", | |
| "\n", | |
| "where\n", | |
| "\n", | |
| "$$\n", | |
| "\\lambda_{ij} = \\frac{\\partial C(s_i - s_j)}{\\partial s_i} = \\sigma \\left(\\frac{1}{2}(1 - S_{ij}) - \\frac{1}{1 + e^{\\sigma(s_i - s_j)}}\\right)\n", | |
| "$$\n", | |
| "\n", | |
| "\n", | |
| "In lambdaRank:\n", | |
| "1. predict the labels for each sample\n", | |
| "2. for each pair i, j where $y_i > y_j$ (i.e. $S_{ij} = 1$) you calculate the effect that switching their predicted place in the sorted list has on NDCG@N ($\\Delta \\text{NDCG}_{ij}$).\n", | |
| "3. You replace the derivative of the loss w.r.t. $s_i$ ($\\lambda_{ij}$) with the following:\n", | |
| "\n", | |
| "\n", | |
| "$$\n", | |
| "\\lambda_{ij} = \\frac{-\\sigma}{1 + e^{\\sigma(s_i - s_j)}} \\left|\\Delta \\text{NDCG}_{ij}\\right|\n", | |
| "$$\n", | |
| "\n", | |
| "\n", | |
| "\n", | |
| "### LambdaMART\n", | |
| "\n", | |
| "Does basically the same thing as lambdaRank, except the $\\lambda_{ij}$ is used when calculating the gradients for adding new boosting trees instead of gradient descent.\n", | |
| "\n", | |
| "Highly recommend \"From RankNet to LambdaRank to LambdaMART: An Overview\" here: https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/MSR-TR-2010-82.pdf\n", | |
| "\n", | |
| "Quite a clear description of the conceptual process from the Microsoft people that did this work. And reasonably complete equations and descriptions.\n", | |
| "\n", | |
| "\n", | |
| "XGBoost has a good but poorly documented implementation of LambdaMART.\n", | |
| "\n", | |
| "LightGBM seems to have \"lambdarank\" and another model XENDCG (https://arxiv.org/abs/1911.09798; related to listNet)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 59, | |
| "id": "champion-jewel", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import xgboost as xgb" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 60, | |
| "id": "greek-sherman", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from sklearn.datasets import dump_svmlight_file\n", | |
| "dump_svmlight_file(X[:9000], y[:9000], \"train.svmlight\")\n", | |
| "\n", | |
| "with open(\"train.svmlight.group\", \"w\") as handle:\n", | |
| " print(\"\\n\".join(str(100) for _ in range(90)), file=handle)\n", | |
| " \n", | |
| "dump_svmlight_file(X[9000:], y[9000:], \"test.svmlight\")\n", | |
| "\n", | |
| "with open(\"test.svmlight.group\", \"w\") as handle:\n", | |
| " print(\"\\n\".join(str(100) for _ in range(10)), file=handle)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 61, | |
| "id": "statutory-malaysia", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[02:53:25] 9000x2 matrix with 18000 entries loaded from train.svmlight\n", | |
| "[02:53:25] 90 groups are loaded from train.svmlight.group\n", | |
| "[02:53:25] 1000x2 matrix with 2000 entries loaded from test.svmlight\n", | |
| "[02:53:25] 10 groups are loaded from test.svmlight.group\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "train = xgb.DMatrix(\"train.svmlight\")\n", | |
| "test = xgb.DMatrix(\"test.svmlight\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 62, | |
| "id": "stone-organ", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "params = {\n", | |
| " 'verbosity': 0,\n", | |
| " 'max_depth': 4,\n", | |
| " 'objective': 'rank:ndcg', # can also do rank:pairwise or rank:map\n", | |
| " 'eval_metric': [\n", | |
| " 'ndcg',\n", | |
| " 'ndcg@10',\n", | |
| " 'map',\n", | |
| " 'map@10',\n", | |
| " ]\n", | |
| "}" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 63, | |
| "id": "circular-personal", | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[0]\ttest-ndcg:0.905414\ttest-ndcg@10:0.681215\ttest-map:0.987547\ttest-map@10:0.106358\n", | |
| "[1]\ttest-ndcg:0.930646\ttest-ndcg@10:0.779804\ttest-map:0.98844\ttest-map@10:0.106358\n", | |
| "[2]\ttest-ndcg:0.931475\ttest-ndcg@10:0.781137\ttest-map:0.988497\ttest-map@10:0.106358\n", | |
| "[3]\ttest-ndcg:0.960732\ttest-ndcg@10:0.903813\ttest-map:0.988633\ttest-map@10:0.106358\n", | |
| "[4]\ttest-ndcg:0.960481\ttest-ndcg@10:0.910845\ttest-map:0.988403\ttest-map@10:0.106358\n", | |
| "[5]\ttest-ndcg:0.960985\ttest-ndcg@10:0.911291\ttest-map:0.988541\ttest-map@10:0.106358\n", | |
| "[6]\ttest-ndcg:0.964712\ttest-ndcg@10:0.920337\ttest-map:0.988581\ttest-map@10:0.106358\n", | |
| "[7]\ttest-ndcg:0.965182\ttest-ndcg@10:0.920337\ttest-map:0.988431\ttest-map@10:0.106358\n", | |
| "[8]\ttest-ndcg:0.97069\ttest-ndcg@10:0.936699\ttest-map:0.988225\ttest-map@10:0.106358\n", | |
| "[9]\ttest-ndcg:0.971002\ttest-ndcg@10:0.937074\ttest-map:0.988226\ttest-map@10:0.106358\n", | |
| "[10]\ttest-ndcg:0.97941\ttest-ndcg@10:0.955805\ttest-map:0.988309\ttest-map@10:0.106358\n", | |
| "[11]\ttest-ndcg:0.984046\ttest-ndcg@10:0.966782\ttest-map:0.988217\ttest-map@10:0.106358\n", | |
| "[12]\ttest-ndcg:0.984086\ttest-ndcg@10:0.966782\ttest-map:0.988282\ttest-map@10:0.106358\n", | |
| "[13]\ttest-ndcg:0.984272\ttest-ndcg@10:0.966953\ttest-map:0.988383\ttest-map@10:0.106358\n", | |
| "[14]\ttest-ndcg:0.984449\ttest-ndcg@10:0.967277\ttest-map:0.98833\ttest-map@10:0.106358\n", | |
| "[15]\ttest-ndcg:0.984635\ttest-ndcg@10:0.967482\ttest-map:0.988316\ttest-map@10:0.106358\n", | |
| "[16]\ttest-ndcg:0.984704\ttest-ndcg@10:0.967482\ttest-map:0.988686\ttest-map@10:0.106358\n", | |
| "[17]\ttest-ndcg:0.984704\ttest-ndcg@10:0.967405\ttest-map:0.988562\ttest-map@10:0.106358\n", | |
| "[18]\ttest-ndcg:0.984779\ttest-ndcg@10:0.967106\ttest-map:0.988665\ttest-map@10:0.106358\n", | |
| "[19]\ttest-ndcg:0.984641\ttest-ndcg@10:0.966782\ttest-map:0.988944\ttest-map@10:0.106358\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "model = xgb.train(params, train, evals=[(test, \"test\")], num_boost_round=20, verbose_eval=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 64, | |
| "id": "spoken-space", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "ytest = y[9000:].reshape(-1, 100)\n", | |
| "preds = model.predict(test).reshape(-1, 100)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 65, | |
| "id": "verbal-amber", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.9889175984677776" | |
| ] | |
| }, | |
| "execution_count": 65, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "normalised_dcg(ytest, preds)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 66, | |
| "id": "scenic-facility", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.9943802083769608" | |
| ] | |
| }, | |
| "execution_count": 66, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "normalised_dcg(ytest, preds, original=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 67, | |
| "id": "comic-vegetarian", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.9913811457214813" | |
| ] | |
| }, | |
| "execution_count": 67, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "ndcg_score(ytest, preds)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "continued-parade", | |
| "metadata": {}, | |
| "source": [ | |
| "None of the NDCG scores are the same :S" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "atomic-draft", | |
| "metadata": {}, | |
| "source": [ | |
| "### Beyond LambdaRank\n", | |
| "\n", | |
| "Soon after lambdaRank people started looking for \"smooth\" (i.e. differentiable) loss functions approximating ranking/IR metrics.\n", | |
| "\n", | |
| "- softRank\n", | |
| " - https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/SoftRankWsdm08Submitted.pdf\n", | |
| "- listNet\n", | |
| " - https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/tr-2007-40.pdf\n", | |
| " - Attempts to define a probability for a particular list permutation (over all possible permutations). Like a long chain of conditional probabilities.\n", | |
| " - Heuristic approximation \"top 1\". Sum of permutation probabilities for permutations where the current observation is highest ranked.\n", | |
| "- listMLE\n", | |
| " - https://dl.acm.org/doi/abs/10.1145/1390156.1390306\n", | |
| " - Uses a likelihood loss, which they are a bit more formal about defining than listNet was.\n", | |
| "- Bayesian personalised ranking (BPR)\n", | |
| " - Popular in recommenders\n", | |
| " - https://arxiv.org/ftp/arxiv/papers/1205/1205.2618.pdf\n", | |
| "\n", | |
| "\n", | |
| "### ListMLE\n", | |
| " \n", | |
| "listMLE likelihood loss function where $g$ is your function resulting in ranking scores.\n", | |
| "\n", | |
| "$$\n", | |
| "l(g(x), y) = -\\log P(y | x; g) \\\\\n", | |
| "P(y | x; g) = \\prod_{i=1}^{n} \\frac{\\exp(g(x_{y(i)}))}{\\sum_{k=i}^{n}\\exp(g(x_{y(k)}))}\n", | |
| "$$\n", | |
| "\n", | |
| "Where $y(i)$ is the index of the object ranked at position $i$.\n", | |
| "So $x_{y(i)}$ retrieves that document.\n", | |
| "\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 68, | |
| "id": "postal-syndicate", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "@tf.function\n", | |
| "def list_mle_loss(true, pred):\n", | |
| " order = tf.argsort(true, direction=\"DESCENDING\")\n", | |
| " pred_sort = tf.gather(pred, order)\n", | |
| "\n", | |
| " pred_sort -= tf.reduce_max(pred_sort, axis=-1)\n", | |
| " sums = tf.cumsum(tf.exp(pred_sort), axis=-1, reverse=True)\n", | |
| " sums = -1 * (pred_sort - tf.math.log(sums))\n", | |
| " return tf.reduce_sum(sums, keepdims=True, axis=-1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 69, | |
| "id": "killing-bhutan", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([1. , 0.98989899, 0.97979798, 0.96969697, 0.95959596,\n", | |
| " 0.94949495, 0.93939394, 0.92929293, 0.91919192, 0.90909091,\n", | |
| " 0.8989899 , 0.88888889, 0.87878788, 0.86868687, 0.85858586,\n", | |
| " 0.84848485, 0.83838384, 0.82828283, 0.81818182, 0.80808081,\n", | |
| " 0.7979798 , 0.78787879, 0.77777778, 0.76767677, 0.75757576,\n", | |
| " 0.74747475, 0.73737374, 0.72727273, 0.71717172, 0.70707071,\n", | |
| " 0.6969697 , 0.68686869, 0.67676768, 0.66666667, 0.65656566,\n", | |
| " 0.64646465, 0.63636364, 0.62626263, 0.61616162, 0.60606061,\n", | |
| " 0.5959596 , 0.58585859, 0.57575758, 0.56565657, 0.55555556,\n", | |
| " 0.54545455, 0.53535354, 0.52525253, 0.51515152, 0.50505051,\n", | |
| " 0.49494949, 0.48484848, 0.47474747, 0.46464646, 0.45454545,\n", | |
| " 0.44444444, 0.43434343, 0.42424242, 0.41414141, 0.4040404 ,\n", | |
| " 0.39393939, 0.38383838, 0.37373737, 0.36363636, 0.35353535,\n", | |
| " 0.34343434, 0.33333333, 0.32323232, 0.31313131, 0.3030303 ,\n", | |
| " 0.29292929, 0.28282828, 0.27272727, 0.26262626, 0.25252525,\n", | |
| " 0.24242424, 0.23232323, 0.22222222, 0.21212121, 0.2020202 ,\n", | |
| " 0.19191919, 0.18181818, 0.17171717, 0.16161616, 0.15151515,\n", | |
| " 0.14141414, 0.13131313, 0.12121212, 0.11111111, 0.1010101 ,\n", | |
| " 0.09090909, 0.08080808, 0.07070707, 0.06060606, 0.05050505,\n", | |
| " 0.04040404, 0.03030303, 0.02020202, 0.01010101, 0. ])" | |
| ] | |
| }, | |
| "execution_count": 69, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "ysorted = np.arange(99, -1, step=-1).astype(float)\n", | |
| "ypred = np.linspace(1, 0, 100)\n", | |
| "ypred" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 70, | |
| "id": "annoying-ministry", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([340.17004064])" | |
| ] | |
| }, | |
| "execution_count": 70, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "list_mle_loss(ysorted, ypred).numpy()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 71, | |
| "id": "analyzed-conservative", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([390.17004064])" | |
| ] | |
| }, | |
| "execution_count": 71, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "list_mle_loss(ysorted, ypred[::-1]).numpy()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 72, | |
| "id": "minus-aspect", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0. , 0.98989899, 0.97979798, 0.96969697, 0.95959596,\n", | |
| " 0.94949495, 0.93939394, 0.92929293, 0.91919192, 0.90909091,\n", | |
| " 0.8989899 , 0.88888889, 0.87878788, 0.86868687, 0.85858586,\n", | |
| " 0.84848485, 0.83838384, 0.82828283, 0.81818182, 0.80808081,\n", | |
| " 0.7979798 , 0.78787879, 0.77777778, 0.76767677, 0.75757576,\n", | |
| " 0.74747475, 0.73737374, 0.72727273, 0.71717172, 0.70707071,\n", | |
| " 0.6969697 , 0.68686869, 0.67676768, 0.66666667, 0.65656566,\n", | |
| " 0.64646465, 0.63636364, 0.62626263, 0.61616162, 0.60606061,\n", | |
| " 0.5959596 , 0.58585859, 0.57575758, 0.56565657, 0.55555556,\n", | |
| " 0.54545455, 0.53535354, 0.52525253, 0.51515152, 0.50505051,\n", | |
| " 0.49494949, 0.48484848, 0.47474747, 0.46464646, 0.45454545,\n", | |
| " 0.44444444, 0.43434343, 0.42424242, 0.41414141, 0.4040404 ,\n", | |
| " 0.39393939, 0.38383838, 0.37373737, 0.36363636, 0.35353535,\n", | |
| " 0.34343434, 0.33333333, 0.32323232, 0.31313131, 0.3030303 ,\n", | |
| " 0.29292929, 0.28282828, 0.27272727, 0.26262626, 0.25252525,\n", | |
| " 0.24242424, 0.23232323, 0.22222222, 0.21212121, 0.2020202 ,\n", | |
| " 0.19191919, 0.18181818, 0.17171717, 0.16161616, 0.15151515,\n", | |
| " 0.14141414, 0.13131313, 0.12121212, 0.11111111, 0.1010101 ,\n", | |
| " 0.09090909, 0.08080808, 0.07070707, 0.06060606, 0.05050505,\n", | |
| " 0.04040404, 0.03030303, 0.02020202, 0.01010101, 1. ])" | |
| ] | |
| }, | |
| "execution_count": 72, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "ypred_mut = ypred.copy()\n", | |
| "ypred_mut[0] = ypred[-1]\n", | |
| "ypred_mut[-1] = ypred[0]\n", | |
| "ypred_mut" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 73, | |
| "id": "italian-victorian", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([346.93373969])" | |
| ] | |
| }, | |
| "execution_count": 73, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "list_mle_loss(ysorted, ypred_mut).numpy()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 74, | |
| "id": "declared-johns", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.98989899, 1. , 0.97979798, 0.96969697, 0.95959596,\n", | |
| " 0.94949495, 0.93939394, 0.92929293, 0.91919192, 0.90909091,\n", | |
| " 0.8989899 , 0.88888889, 0.87878788, 0.86868687, 0.85858586,\n", | |
| " 0.84848485, 0.83838384, 0.82828283, 0.81818182, 0.80808081,\n", | |
| " 0.7979798 , 0.78787879, 0.77777778, 0.76767677, 0.75757576,\n", | |
| " 0.74747475, 0.73737374, 0.72727273, 0.71717172, 0.70707071,\n", | |
| " 0.6969697 , 0.68686869, 0.67676768, 0.66666667, 0.65656566,\n", | |
| " 0.64646465, 0.63636364, 0.62626263, 0.61616162, 0.60606061,\n", | |
| " 0.5959596 , 0.58585859, 0.57575758, 0.56565657, 0.55555556,\n", | |
| " 0.54545455, 0.53535354, 0.52525253, 0.51515152, 0.50505051,\n", | |
| " 0.49494949, 0.48484848, 0.47474747, 0.46464646, 0.45454545,\n", | |
| " 0.44444444, 0.43434343, 0.42424242, 0.41414141, 0.4040404 ,\n", | |
| " 0.39393939, 0.38383838, 0.37373737, 0.36363636, 0.35353535,\n", | |
| " 0.34343434, 0.33333333, 0.32323232, 0.31313131, 0.3030303 ,\n", | |
| " 0.29292929, 0.28282828, 0.27272727, 0.26262626, 0.25252525,\n", | |
| " 0.24242424, 0.23232323, 0.22222222, 0.21212121, 0.2020202 ,\n", | |
| " 0.19191919, 0.18181818, 0.17171717, 0.16161616, 0.15151515,\n", | |
| " 0.14141414, 0.13131313, 0.12121212, 0.11111111, 0.1010101 ,\n", | |
| " 0.09090909, 0.08080808, 0.07070707, 0.06060606, 0.05050505,\n", | |
| " 0.04040404, 0.03030303, 0.02020202, 0.01010101, 0. ])" | |
| ] | |
| }, | |
| "execution_count": 74, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "ypred_mut = ypred.copy()\n", | |
| "ypred_mut[0] = ypred[1]\n", | |
| "ypred_mut[1] = ypred[0]\n", | |
| "ypred_mut" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 75, | |
| "id": "alpha-traveler", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([340.17020204])" | |
| ] | |
| }, | |
| "execution_count": 75, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "list_mle_loss(ysorted, ypred_mut).numpy()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 76, | |
| "id": "cooked-swing", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([1. , 0.98989899, 0.97979798, 0.96969697, 0.95959596,\n", | |
| " 0.94949495, 0.93939394, 0.92929293, 0.91919192, 0.90909091,\n", | |
| " 0.8989899 , 0.88888889, 0.87878788, 0.86868687, 0.85858586,\n", | |
| " 0.84848485, 0.83838384, 0.82828283, 0.81818182, 0.80808081,\n", | |
| " 0.7979798 , 0.78787879, 0.77777778, 0.76767677, 0.75757576,\n", | |
| " 0.74747475, 0.73737374, 0.72727273, 0.71717172, 0.70707071,\n", | |
| " 0.6969697 , 0.68686869, 0.67676768, 0.66666667, 0.65656566,\n", | |
| " 0.64646465, 0.63636364, 0.62626263, 0.61616162, 0.60606061,\n", | |
| " 0.5959596 , 0.58585859, 0.57575758, 0.56565657, 0.55555556,\n", | |
| " 0.54545455, 0.53535354, 0.52525253, 0.51515152, 0.50505051,\n", | |
| " 0.49494949, 0.48484848, 0.47474747, 0.46464646, 0.45454545,\n", | |
| " 0.44444444, 0.43434343, 0.42424242, 0.41414141, 0.4040404 ,\n", | |
| " 0.39393939, 0.38383838, 0.37373737, 0.36363636, 0.35353535,\n", | |
| " 0.34343434, 0.33333333, 0.32323232, 0.31313131, 0.3030303 ,\n", | |
| " 0.29292929, 0.28282828, 0.27272727, 0.26262626, 0.25252525,\n", | |
| " 0.24242424, 0.23232323, 0.22222222, 0.21212121, 0.2020202 ,\n", | |
| " 0.19191919, 0.18181818, 0.17171717, 0.16161616, 0.15151515,\n", | |
| " 0.14141414, 0.13131313, 0.12121212, 0.11111111, 0.1010101 ,\n", | |
| " 0.09090909, 0.08080808, 0.07070707, 0.06060606, 0.05050505,\n", | |
| " 0.04040404, 0.03030303, 0.02020202, 0. , 0. ])" | |
| ] | |
| }, | |
| "execution_count": 76, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "ypred_mut = ypred.copy()\n", | |
| "ypred_mut[-1] = ypred[-1]\n", | |
| "ypred_mut[-2] = ypred[-1]\n", | |
| "ypred_mut" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 77, | |
| "id": "outer-article", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([340.14205151])" | |
| ] | |
| }, | |
| "execution_count": 77, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "list_mle_loss(ysorted, ypred_mut).numpy()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "psychological-optics", | |
| "metadata": {}, | |
| "source": [ | |
| "Let's try training a model!" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 78, | |
| "id": "civil-belgium", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "Xsplit = list(np.split(X[:9000], np.arange(100, 9000, 100)))\n", | |
| "ysplit = list(np.split(y[:9000], np.arange(100, 9000, 100)))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 79, | |
| "id": "bulgarian-argentina", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "tf.keras.backend.clear_session()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 80, | |
| "id": "outer-throat", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "0 [370.86136] 0.7423732225166043\n", | |
| "10 [354.9018] 0.9205463108498243\n", | |
| "20 [346.13977] 1.0\n", | |
| "30 [346.25073] 0.9977990446115134\n", | |
| "40 [341.0366] 0.9972605686529491\n", | |
| "50 [337.79468] 1.0\n", | |
| "60 [341.48578] 0.9999826230597305\n", | |
| "70 [339.2522] 0.9999943836284894\n", | |
| "80 [339.8788] 0.999715882466915\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "optimizer = Adam(learning_rate=0.1)\n", | |
| "loss_fn = list_mle_loss\n", | |
| "\n", | |
| "model = Dense(1, activation=\"linear\", use_bias=False)\n", | |
| "\n", | |
| "for i, (Xi, yi) in enumerate(zip(Xsplit, ysplit)):\n", | |
| " with tf.GradientTape() as tape:\n", | |
| " pred = model(Xi)\n", | |
| " logits = tfsigmoid(pred)\n", | |
| " loss_value = loss_fn(yi, tf.reshape(logits, -1))\n", | |
| "\n", | |
| " grads = tape.gradient(loss_value, model.trainable_variables)\n", | |
| " optimizer.apply_gradients(zip(grads, model.trainable_variables))\n", | |
| " #print(sorted(zip(pred.numpy().reshape(-1), yi), key=lambda t: t[0]))\n", | |
| "\n", | |
| " if (i % 10) == 0:\n", | |
| " print(i, loss_value.numpy(), normalised_dcg(yi.reshape(1, -1), pred.numpy().reshape(1, -1)))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 81, | |
| "id": "congressional-enough", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "Xsplit = list(np.split(X[9000:], np.arange(100, 1000, 100)))\n", | |
| "ysplit = list(np.split(y[9000:], np.arange(100, 1000, 100)))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 82, | |
| "id": "pending-corpus", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]" | |
| ] | |
| }, | |
| "execution_count": 82, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "preds = []\n", | |
| "ndcgs = []\n", | |
| "for Xi, yi in zip(Xsplit, ysplit):\n", | |
| " p = model(Xi)\n", | |
| " preds.append(p.numpy().flatten())\n", | |
| " ndcgs.append(normalised_dcg(yi.reshape(1, -1), p.numpy().reshape(1, -1)))\n", | |
| "\n", | |
| "preds = pd.DataFrame({\"true\": ysplit[0], \"preds\": preds[0]})\n", | |
| "\n", | |
| "ndcgs" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 83, | |
| "id": "consolidated-surfing", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>true</th>\n", | |
| " <th>preds</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>30</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>-8.204016</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>87</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>-7.504719</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>-7.269742</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>88</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>-6.528606</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>82</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>-6.395883</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>...</th>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>89</th>\n", | |
| " <td>2.0</td>\n", | |
| " <td>6.246413</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>33</th>\n", | |
| " <td>2.0</td>\n", | |
| " <td>6.354847</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>16</th>\n", | |
| " <td>2.0</td>\n", | |
| " <td>7.383247</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>29</th>\n", | |
| " <td>2.0</td>\n", | |
| " <td>7.794843</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>99</th>\n", | |
| " <td>3.0</td>\n", | |
| " <td>11.692219</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "<p>100 rows × 2 columns</p>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " true preds\n", | |
| "30 0.0 -8.204016\n", | |
| "87 0.0 -7.504719\n", | |
| "0 0.0 -7.269742\n", | |
| "88 0.0 -6.528606\n", | |
| "82 0.0 -6.395883\n", | |
| ".. ... ...\n", | |
| "89 2.0 6.246413\n", | |
| "33 2.0 6.354847\n", | |
| "16 2.0 7.383247\n", | |
| "29 2.0 7.794843\n", | |
| "99 3.0 11.692219\n", | |
| "\n", | |
| "[100 rows x 2 columns]" | |
| ] | |
| }, | |
| "execution_count": 83, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "preds.sort_values(\"preds\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 84, | |
| "id": "micro-kitchen", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[array([[3.8863764],\n", | |
| " [1.0636641]], dtype=float32)]" | |
| ] | |
| }, | |
| "execution_count": 84, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model.get_weights()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "rotary-baseline", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "necessary-manitoba", | |
| "metadata": {}, | |
| "source": [ | |
| "Some libraries for doing LTR.\n", | |
| "\n", | |
| "- tf-ranking https://github.com/tensorflow/ranking\n", | |
| "- allrank (pytorch) https://github.com/allegro/allRank/tree/master/allrank\n", | |
| "- lightGBM https://lightgbm.readthedocs.io/en/latest/index.html\n", | |
| "- XGBoost https://xgboost.readthedocs.io/en/latest/ (doesn't work with the sk-learn estimator API).\n", | |
| "- SVMRank https://www.cs.cornell.edu/people/tj/svm_light/svm_rank.html https://www.cs.cornell.edu/people/tj/svm_light/svm_proprank.html\n", | |
| "- https://github.com/hpclab/quickrank\n", | |
| "- A lot of options dedicated to search engine stuff e.g. https://elasticsearch-learning-to-rank.readthedocs.io/en/latest/\n", | |
| "\n", | |
| "\n", | |
| "There might be more complete libraries out there for recommenders that incorporate some LTR features. \n", | |
| "\n", | |
| "- https://github.com/maciejkula/spotlight\n", | |
| "\n", | |
| "\n", | |
| "\n", | |
| "Most of the current focus is around correcting for biased feedback (e.g. clicks) and online learning.\n", | |
| "\n", | |
| "Position bias means higher ranked items are more likely to be clicked.\n", | |
| "Selection bias means that some relevant documents might never reach the stage of ranking (i.e. after the initial document retrieval stage).\n", | |
| "\n", | |
| "- Good overview of biased sampling and ML overall https://dl.acm.org/doi/pdf/10.1145/1015330.1015425\n", | |
| "\n", | |
| "- Click models. Attempt to rank so as to maximise the number of clicks. Refers to \"click bandits\" or \"multi-armed bandit learners\" which is fun. http://proceedings.mlr.press/v70/zoghi17a/zoghi17a.pdf \n", | |
| "- Inverse propensity scores https://arxiv.org/abs/1608.04468\n", | |
| "- Heckman 2 stage & predict ranking as counterfactual, i.e. predict the probability that the document would have been clicked had it been presented to the user (https://dl.acm.org/doi/abs/10.1145/3366423.3380255)\n", | |
| "- Online learning and bias correction https://dl.acm.org/doi/pdf/10.1145/3269206.3271686\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "velvet-halifax", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "virgin-aberdeen", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "announced-shopping", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "united-programming", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "functioning-parliament", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "fifth-value", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "younger-marshall", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "applied-bosnia", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "native-chosen", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "criminal-liberia", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "hidden-noise", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "persistent-fisher", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "cutting-entity", | |
| "metadata": {}, | |
| "source": [ | |
| "### I had this half-baked idea about some relationship between pairwise learning approaches and matched samples." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 85, | |
| "id": "technical-findings", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "rng = default_rng()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 86, | |
| "id": "banned-seattle", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0, 0, 1, ..., 1, 0, 0],\n", | |
| " [0, 1, 0, ..., 0, 1, 0],\n", | |
| " [0, 1, 0, ..., 0, 0, 0],\n", | |
| " ...,\n", | |
| " [0, 1, 0, ..., 0, 1, 0],\n", | |
| " [0, 0, 1, ..., 0, 0, 1],\n", | |
| " [1, 0, 1, ..., 1, 1, 0]])" | |
| ] | |
| }, | |
| "execution_count": 86, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "nloci = 1000\n", | |
| "nsamples = 101\n", | |
| "freqs_pop1 = rng.uniform(size=nloci)\n", | |
| "\n", | |
| "pop1 = np.apply_along_axis(\n", | |
| " lambda p: rng.choice([1, 0], size=nsamples, p=[p[0], 1 - p[0]]),\n", | |
| " 0,\n", | |
| " freqs_pop1.reshape(1, -1)\n", | |
| ")\n", | |
| "\n", | |
| "pop1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 87, | |
| "id": "vocational-diversity", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.35643564, 0.36633663, 0.34653465, 0.92079208, 0.72277228])" | |
| ] | |
| }, | |
| "execution_count": 87, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pop1.mean(axis=0)[:5]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 88, | |
| "id": "small-beaver", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.42173889, 0.28234691, 0.32596702, 0.87615543, 0.7065007 ])" | |
| ] | |
| }, | |
| "execution_count": 88, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "freqs_pop1[:5]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 89, | |
| "id": "beginning-occupation", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.8, 0.7, 0.2, 0.3, 0.7])" | |
| ] | |
| }, | |
| "execution_count": 89, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "nsamples2 = int(nsamples / 10)\n", | |
| "freqs_pop2 = rng.uniform(size=nloci)\n", | |
| "pop2 = np.apply_along_axis(\n", | |
| " lambda p: rng.choice([1, 0], size=nsamples2, p=[p[0], 1 - p[0]]),\n", | |
| " 0,\n", | |
| " freqs_pop2.reshape(1, -1)\n", | |
| ")\n", | |
| "\n", | |
| "pop2.mean(axis=0)[:5]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 90, | |
| "id": "given-consent", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.71669173, 0.55885922, 0.16555695, 0.11097242, 0.61347501])" | |
| ] | |
| }, | |
| "execution_count": 90, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "freqs_pop2[:5]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 91, | |
| "id": "through-welding", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(111, 1000)" | |
| ] | |
| }, | |
| "execution_count": 91, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X = np.concatenate([pop1, pop2], axis=0)\n", | |
| "X.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 92, | |
| "id": "ranking-antigua", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([616, 942, 359, 399, 389, 399, 706, 424, 476, 279])" | |
| ] | |
| }, | |
| "execution_count": 92, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "freq_diff = np.abs(freqs_pop1 - freqs_pop2)\n", | |
| "freq_diff /= freq_diff.sum()\n", | |
| "\n", | |
| "n_true_loci = 10\n", | |
| "true_effect_loci = rng.choice(np.arange(nloci), size=n_true_loci, p=freq_diff)\n", | |
| "true_effect_loci" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 93, | |
| "id": "enabling-soviet", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.00219396, 0.00155591, 0.00167062, 0.00108379, 0.00179073,\n", | |
| " 0.00108379, 0.00060603, 0.00144228, 0.00250293, 0.00146181])" | |
| ] | |
| }, | |
| "execution_count": 93, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "freq_diff[true_effect_loci]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 94, | |
| "id": "sophisticated-groove", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([ 1.58211257, -3.31679662, -1.65575129, -0.29111095, 2.48020977,\n", | |
| " -0.29111095, 0.615461 , -4.98735786, 1.59195686, -0.21220992])" | |
| ] | |
| }, | |
| "execution_count": 94, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "true_effects = np.zeros(nloci)\n", | |
| "true_effects[true_effect_loci] = rng.normal(scale=2, size=n_true_loci)\n", | |
| "true_effects[true_effect_loci]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 95, | |
| "id": "removable-myanmar", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([-2.89003618, -7.37392783, 0.53139135, -2.64806911, 1.31172872,\n", | |
| " -6.03047451, -5.73292675, -5.54473542, 1.21794124, -6.08810311,\n", | |
| " -5.06613556, -2.17749576, 1.153398 , 0.58935916, -5.1086086 ,\n", | |
| " -4.13321636, -6.97651621, 1.02283966, -2.35137949, -4.94724517,\n", | |
| " -4.86981345, -2.43990861, -2.51994556, -9.27638919, -2.42012765,\n", | |
| " -4.35831236, -4.76393238, -4.85760981, -1.2375028 , -5.5706874 ,\n", | |
| " -1.95124271, -8.19967383, -5.31216132, -0.15281248, -4.90213372,\n", | |
| " -3.32718025, -3.83720197, -2.90110106, -2.24997491, -6.53388119,\n", | |
| " -4.86865215, 1.56853564, -4.84568096, -5.24571228, -4.34553181,\n", | |
| " 1.19739964, -8.96300407, -6.07413718, -5.66936917, -9.10162307,\n", | |
| " -7.626954 , -4.96612254, -9.0948901 , -4.24148847, -5.0913764 ,\n", | |
| " -4.52476997, -4.91708561, -5.7047012 , -5.18291892, -5.6231925 ,\n", | |
| " -3.74753687, -2.38644182, -5.1069895 , -4.54746666, 0.32610541,\n", | |
| " -7.53173038, -2.78363444, -6.16385426, -2.34379364, -4.144581 ,\n", | |
| " -2.5137671 , 1.75003848, -7.39191795, -2.31144608, 0.2806683 ,\n", | |
| " -6.84167982, -6.21334354, -7.54033069, -8.86755271, -1.33341253,\n", | |
| " 2.13838086, -4.36326148, -2.64128861, -5.25058532, -5.1479919 ,\n", | |
| " -3.89250774, -4.48815988, 1.14517427, -1.76937394, -7.3322457 ,\n", | |
| " -8.07716692, -2.12339758, -4.16387593, 2.66831327, -7.99565036,\n", | |
| " -6.44008276, -6.67012985, -2.36474471, -3.83451463, 0.25120136,\n", | |
| " -1.66781202, 0.64495563, 0.27051483, -4.05533118, -4.14705515,\n", | |
| " -3.96143322, -0.24173033, -3.7781134 , 1.48863394, 0.82347723,\n", | |
| " 0.71743525])" | |
| ] | |
| }, | |
| "execution_count": 95, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "y = true_effects.dot(X.T)\n", | |
| "#y[:nsamples] += 2\n", | |
| "y += rng.normal(scale=0.5, size=len(y))\n", | |
| "y" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 96, | |
| "id": "critical-coating", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([-2.89003618, -7.37392783, 0.53139135, -2.64806911, 1.31172872,\n", | |
| " -6.03047451, -5.73292675, -5.54473542, 1.21794124, -6.08810311,\n", | |
| " -5.06613556, -2.17749576, 1.153398 , 0.58935916, -5.1086086 ,\n", | |
| " -4.13321636, -6.97651621, 1.02283966, -2.35137949, -4.94724517,\n", | |
| " -4.86981345, -2.43990861, -2.51994556, -9.27638919, -2.42012765,\n", | |
| " -4.35831236, -4.76393238, -4.85760981, -1.2375028 , -5.5706874 ,\n", | |
| " -1.95124271, -8.19967383, -5.31216132, -0.15281248, -4.90213372,\n", | |
| " -3.32718025, -3.83720197, -2.90110106, -2.24997491, -6.53388119,\n", | |
| " -4.86865215, 1.56853564, -4.84568096, -5.24571228, -4.34553181,\n", | |
| " 1.19739964, -8.96300407, -6.07413718, -5.66936917, -9.10162307,\n", | |
| " -7.626954 , -4.96612254, -9.0948901 , -4.24148847, -5.0913764 ,\n", | |
| " -4.52476997, -4.91708561, -5.7047012 , -5.18291892, -5.6231925 ,\n", | |
| " -3.74753687, -2.38644182, -5.1069895 , -4.54746666, 0.32610541,\n", | |
| " -7.53173038, -2.78363444, -6.16385426, -2.34379364, -4.144581 ,\n", | |
| " -2.5137671 , 1.75003848, -7.39191795, -2.31144608, 0.2806683 ,\n", | |
| " -6.84167982, -6.21334354, -7.54033069, -8.86755271, -1.33341253,\n", | |
| " 2.13838086, -4.36326148, -2.64128861, -5.25058532, -5.1479919 ,\n", | |
| " -3.89250774, -4.48815988, 1.14517427, -1.76937394, -7.3322457 ,\n", | |
| " -8.07716692, -2.12339758, -4.16387593, 2.66831327, -7.99565036,\n", | |
| " -6.44008276, -6.67012985, -2.36474471, -3.83451463, 0.25120136,\n", | |
| " -1.66781202, 0.64495563, 0.27051483, -4.05533118, -4.14705515,\n", | |
| " -3.96143322, -0.24173033, -3.7781134 , 1.48863394, 0.82347723,\n", | |
| " 0.71743525])" | |
| ] | |
| }, | |
| "execution_count": 96, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "y" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 97, | |
| "id": "rental-museum", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from scipy.spatial.distance import pdist, squareform" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 98, | |
| "id": "initial-generic", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import seaborn as sns" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 99, | |
| "id": "related-burner", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<seaborn.matrix.ClusterGrid at 0x7f8b4c0b3a50>" | |
| ] | |
| }, | |
| "execution_count": 99, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 720x720 with 4 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "sns.clustermap(squareform(pdist(X)))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 100, | |
| "id": "passing-joseph", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import pandas as pd\n", | |
| "from sklearn.linear_model import LinearRegression, Lasso" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 101, | |
| "id": "subject-mortality", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Lasso(alpha=0.1)" | |
| ] | |
| }, | |
| "execution_count": 101, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model = Lasso(alpha=0.1)\n", | |
| "model.fit(X, y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 102, | |
| "id": "durable-three", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>y</th>\n", | |
| " <th>pred</th>\n", | |
| " <th>pop</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>-2.890036</td>\n", | |
| " <td>-2.483981</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>-7.373928</td>\n", | |
| " <td>-5.123978</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>0.531391</td>\n", | |
| " <td>-0.083482</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>-2.648069</td>\n", | |
| " <td>-2.437965</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>1.311729</td>\n", | |
| " <td>0.481124</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>...</th>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>106</th>\n", | |
| " <td>-0.241730</td>\n", | |
| " <td>-0.149222</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>107</th>\n", | |
| " <td>-3.778113</td>\n", | |
| " <td>-4.339746</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>108</th>\n", | |
| " <td>1.488634</td>\n", | |
| " <td>0.296920</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>109</th>\n", | |
| " <td>0.823477</td>\n", | |
| " <td>0.342125</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>110</th>\n", | |
| " <td>0.717435</td>\n", | |
| " <td>0.309029</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "<p>111 rows × 3 columns</p>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " y pred pop\n", | |
| "0 -2.890036 -2.483981 1\n", | |
| "1 -7.373928 -5.123978 1\n", | |
| "2 0.531391 -0.083482 1\n", | |
| "3 -2.648069 -2.437965 1\n", | |
| "4 1.311729 0.481124 1\n", | |
| ".. ... ... ...\n", | |
| "106 -0.241730 -0.149222 2\n", | |
| "107 -3.778113 -4.339746 2\n", | |
| "108 1.488634 0.296920 2\n", | |
| "109 0.823477 0.342125 2\n", | |
| "110 0.717435 0.309029 2\n", | |
| "\n", | |
| "[111 rows x 3 columns]" | |
| ] | |
| }, | |
| "execution_count": 102, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "preds = pd.DataFrame({\"y\": y, \"pred\": model.predict(X), \"pop\": 1})\n", | |
| "preds.loc[nsamples:, \"pop\"] = 2\n", | |
| "preds" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 103, | |
| "id": "latest-thanks", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "pop\n", | |
| "1 0.792984\n", | |
| "2 0.308206\n", | |
| "dtype: float64" | |
| ] | |
| }, | |
| "execution_count": 103, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "preds.groupby(\"pop\").apply(lambda p: ((p[\"y\"] - p[\"pred\"]) ** 2).mean())" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 104, | |
| "id": "higher-married", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>true</th>\n", | |
| " <th>pred</th>\n", | |
| " <th>freq1</th>\n", | |
| " <th>freq2</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>424</th>\n", | |
| " <td>-4.987358</td>\n", | |
| " <td>-4.636665</td>\n", | |
| " <td>0.782155</td>\n", | |
| " <td>0.314272</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>942</th>\n", | |
| " <td>-3.316797</td>\n", | |
| " <td>-2.358242</td>\n", | |
| " <td>0.294759</td>\n", | |
| " <td>0.799501</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>359</th>\n", | |
| " <td>-1.655751</td>\n", | |
| " <td>-0.447206</td>\n", | |
| " <td>0.902841</td>\n", | |
| " <td>0.360887</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>399</th>\n", | |
| " <td>-0.291111</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.528208</td>\n", | |
| " <td>0.879795</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>279</th>\n", | |
| " <td>-0.212210</td>\n", | |
| " <td>-0.000000</td>\n", | |
| " <td>0.437959</td>\n", | |
| " <td>0.912175</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>706</th>\n", | |
| " <td>0.615461</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.589905</td>\n", | |
| " <td>0.786503</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>616</th>\n", | |
| " <td>1.582113</td>\n", | |
| " <td>0.526349</td>\n", | |
| " <td>0.082899</td>\n", | |
| " <td>0.794628</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>476</th>\n", | |
| " <td>1.591957</td>\n", | |
| " <td>0.008964</td>\n", | |
| " <td>0.873217</td>\n", | |
| " <td>0.061256</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>389</th>\n", | |
| " <td>2.480210</td>\n", | |
| " <td>1.641444</td>\n", | |
| " <td>0.239744</td>\n", | |
| " <td>0.820666</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " true pred freq1 freq2\n", | |
| "424 -4.987358 -4.636665 0.782155 0.314272\n", | |
| "942 -3.316797 -2.358242 0.294759 0.799501\n", | |
| "359 -1.655751 -0.447206 0.902841 0.360887\n", | |
| "399 -0.291111 0.000000 0.528208 0.879795\n", | |
| "279 -0.212210 -0.000000 0.437959 0.912175\n", | |
| "706 0.615461 0.000000 0.589905 0.786503\n", | |
| "616 1.582113 0.526349 0.082899 0.794628\n", | |
| "476 1.591957 0.008964 0.873217 0.061256\n", | |
| "389 2.480210 1.641444 0.239744 0.820666" | |
| ] | |
| }, | |
| "execution_count": 104, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "effects = pd.DataFrame({\"true\": true_effects, \"pred\": model.coef_, \"freq1\": freqs_pop1, \"freq2\": freqs_pop2})\n", | |
| "effects[effects[\"true\"] != 0].sort_values(\"true\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 105, | |
| "id": "sought-leader", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>true</th>\n", | |
| " <th>pred</th>\n", | |
| " <th>freq1</th>\n", | |
| " <th>freq2</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>424</th>\n", | |
| " <td>-4.987358</td>\n", | |
| " <td>-4.636665</td>\n", | |
| " <td>0.782155</td>\n", | |
| " <td>0.314272</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>942</th>\n", | |
| " <td>-3.316797</td>\n", | |
| " <td>-2.358242</td>\n", | |
| " <td>0.294759</td>\n", | |
| " <td>0.799501</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>359</th>\n", | |
| " <td>-1.655751</td>\n", | |
| " <td>-0.447206</td>\n", | |
| " <td>0.902841</td>\n", | |
| " <td>0.360887</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>984</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>-0.131787</td>\n", | |
| " <td>0.474286</td>\n", | |
| " <td>0.452758</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>529</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>-0.059976</td>\n", | |
| " <td>0.556470</td>\n", | |
| " <td>0.096112</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>847</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>-0.058744</td>\n", | |
| " <td>0.647811</td>\n", | |
| " <td>0.407043</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>533</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>-0.045375</td>\n", | |
| " <td>0.616533</td>\n", | |
| " <td>0.053800</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>664</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.734732</td>\n", | |
| " <td>0.037101</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>663</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>-0.000000</td>\n", | |
| " <td>0.468895</td>\n", | |
| " <td>0.420180</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>662</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.285418</td>\n", | |
| " <td>0.819648</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " true pred freq1 freq2\n", | |
| "424 -4.987358 -4.636665 0.782155 0.314272\n", | |
| "942 -3.316797 -2.358242 0.294759 0.799501\n", | |
| "359 -1.655751 -0.447206 0.902841 0.360887\n", | |
| "984 0.000000 -0.131787 0.474286 0.452758\n", | |
| "529 0.000000 -0.059976 0.556470 0.096112\n", | |
| "847 0.000000 -0.058744 0.647811 0.407043\n", | |
| "533 0.000000 -0.045375 0.616533 0.053800\n", | |
| "664 0.000000 0.000000 0.734732 0.037101\n", | |
| "663 0.000000 -0.000000 0.468895 0.420180\n", | |
| "662 0.000000 0.000000 0.285418 0.819648" | |
| ] | |
| }, | |
| "execution_count": 105, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "effects.sort_values(\"pred\").head(10)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 106, | |
| "id": "differential-miller", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>true</th>\n", | |
| " <th>pred</th>\n", | |
| " <th>freq1</th>\n", | |
| " <th>freq2</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>337</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.572401</td>\n", | |
| " <td>0.765802</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>476</th>\n", | |
| " <td>1.591957</td>\n", | |
| " <td>0.008964</td>\n", | |
| " <td>0.873217</td>\n", | |
| " <td>0.061256</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>925</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.021720</td>\n", | |
| " <td>0.346882</td>\n", | |
| " <td>0.865417</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>744</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.033096</td>\n", | |
| " <td>0.587489</td>\n", | |
| " <td>0.797726</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>862</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.045205</td>\n", | |
| " <td>0.328772</td>\n", | |
| " <td>0.678966</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>755</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.061040</td>\n", | |
| " <td>0.641671</td>\n", | |
| " <td>0.571747</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>85</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.074928</td>\n", | |
| " <td>0.636585</td>\n", | |
| " <td>0.817731</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>111</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.162889</td>\n", | |
| " <td>0.515959</td>\n", | |
| " <td>0.974238</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>616</th>\n", | |
| " <td>1.582113</td>\n", | |
| " <td>0.526349</td>\n", | |
| " <td>0.082899</td>\n", | |
| " <td>0.794628</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>389</th>\n", | |
| " <td>2.480210</td>\n", | |
| " <td>1.641444</td>\n", | |
| " <td>0.239744</td>\n", | |
| " <td>0.820666</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " true pred freq1 freq2\n", | |
| "337 0.000000 0.000000 0.572401 0.765802\n", | |
| "476 1.591957 0.008964 0.873217 0.061256\n", | |
| "925 0.000000 0.021720 0.346882 0.865417\n", | |
| "744 0.000000 0.033096 0.587489 0.797726\n", | |
| "862 0.000000 0.045205 0.328772 0.678966\n", | |
| "755 0.000000 0.061040 0.641671 0.571747\n", | |
| "85 0.000000 0.074928 0.636585 0.817731\n", | |
| "111 0.000000 0.162889 0.515959 0.974238\n", | |
| "616 1.582113 0.526349 0.082899 0.794628\n", | |
| "389 2.480210 1.641444 0.239744 0.820666" | |
| ] | |
| }, | |
| "execution_count": 106, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "effects.sort_values(\"pred\").tail(10)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "wound-converter", | |
| "metadata": {}, | |
| "source": [ | |
| "## Relief-based feature scoring" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 107, | |
| "id": "known-lotus", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "dists = squareform(pdist(X))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 108, | |
| "id": "divine-capacity", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "nneighbors = 5\n", | |
| "weights = np.zeros(nloci)\n", | |
| "\n", | |
| "k = 1\n", | |
| "n = nsamples + nsamples2\n", | |
| "std = np.std(y)\n", | |
| "for i in range(n):\n", | |
| " closest_samples = np.argsort(dists[i])\n", | |
| "\n", | |
| " nhits = 0\n", | |
| " nmiss = 0\n", | |
| " for j in closest_samples:\n", | |
| " if i == j:\n", | |
| " continue\n", | |
| " elif (nhits > k) and (nmiss > k):\n", | |
| " break\n", | |
| "\n", | |
| " xij = np.abs(X[i,] - X[j,])\n", | |
| " yij = y[i] - y[j]\n", | |
| "\n", | |
| " if abs(yij) < std:\n", | |
| " # Its a hit\n", | |
| " if nhits > k:\n", | |
| " continue\n", | |
| " weights -= xij / (n * k)\n", | |
| " nhits += 1\n", | |
| " else:\n", | |
| " if nmiss > k:\n", | |
| " continue\n", | |
| " weights += xij / (n * k)\n", | |
| " nmiss += 1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 109, | |
| "id": "downtown-sellers", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>true</th>\n", | |
| " <th>pred</th>\n", | |
| " <th>freq1</th>\n", | |
| " <th>freq2</th>\n", | |
| " <th>relief</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>424</th>\n", | |
| " <td>-4.987358</td>\n", | |
| " <td>-4.636665</td>\n", | |
| " <td>0.782155</td>\n", | |
| " <td>0.314272</td>\n", | |
| " <td>0.963964</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>942</th>\n", | |
| " <td>-3.316797</td>\n", | |
| " <td>-2.358242</td>\n", | |
| " <td>0.294759</td>\n", | |
| " <td>0.799501</td>\n", | |
| " <td>0.360360</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>382</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>-0.000000</td>\n", | |
| " <td>0.434281</td>\n", | |
| " <td>0.122560</td>\n", | |
| " <td>0.351351</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>681</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.338203</td>\n", | |
| " <td>0.177349</td>\n", | |
| " <td>0.333333</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>130</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.492619</td>\n", | |
| " <td>0.430562</td>\n", | |
| " <td>0.279279</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>...</th>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>126</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>-0.000000</td>\n", | |
| " <td>0.348554</td>\n", | |
| " <td>0.303019</td>\n", | |
| " <td>-0.261261</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>163</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.333717</td>\n", | |
| " <td>0.824376</td>\n", | |
| " <td>-0.261261</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>695</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.698863</td>\n", | |
| " <td>0.816582</td>\n", | |
| " <td>-0.279279</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>298</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>-0.000000</td>\n", | |
| " <td>0.500859</td>\n", | |
| " <td>0.073400</td>\n", | |
| " <td>-0.279279</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>762</th>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>-0.000000</td>\n", | |
| " <td>0.490440</td>\n", | |
| " <td>0.624082</td>\n", | |
| " <td>-0.297297</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "<p>1000 rows × 5 columns</p>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " true pred freq1 freq2 relief\n", | |
| "424 -4.987358 -4.636665 0.782155 0.314272 0.963964\n", | |
| "942 -3.316797 -2.358242 0.294759 0.799501 0.360360\n", | |
| "382 0.000000 -0.000000 0.434281 0.122560 0.351351\n", | |
| "681 0.000000 0.000000 0.338203 0.177349 0.333333\n", | |
| "130 0.000000 0.000000 0.492619 0.430562 0.279279\n", | |
| ".. ... ... ... ... ...\n", | |
| "126 0.000000 -0.000000 0.348554 0.303019 -0.261261\n", | |
| "163 0.000000 0.000000 0.333717 0.824376 -0.261261\n", | |
| "695 0.000000 0.000000 0.698863 0.816582 -0.279279\n", | |
| "298 0.000000 -0.000000 0.500859 0.073400 -0.279279\n", | |
| "762 0.000000 -0.000000 0.490440 0.624082 -0.297297\n", | |
| "\n", | |
| "[1000 rows x 5 columns]" | |
| ] | |
| }, | |
| "execution_count": 109, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "effects = pd.DataFrame({\"true\": true_effects, \"pred\": model.coef_, \"freq1\": freqs_pop1, \"freq2\": freqs_pop2, \"relief\": weights})\n", | |
| "effects.sort_values(\"relief\", ascending=False)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 110, | |
| "id": "european-contractor", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>true</th>\n", | |
| " <th>pred</th>\n", | |
| " <th>freq1</th>\n", | |
| " <th>freq2</th>\n", | |
| " <th>relief</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>279</th>\n", | |
| " <td>-0.212210</td>\n", | |
| " <td>-0.000000</td>\n", | |
| " <td>0.437959</td>\n", | |
| " <td>0.912175</td>\n", | |
| " <td>0.027027</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>359</th>\n", | |
| " <td>-1.655751</td>\n", | |
| " <td>-0.447206</td>\n", | |
| " <td>0.902841</td>\n", | |
| " <td>0.360887</td>\n", | |
| " <td>0.054054</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>389</th>\n", | |
| " <td>2.480210</td>\n", | |
| " <td>1.641444</td>\n", | |
| " <td>0.239744</td>\n", | |
| " <td>0.820666</td>\n", | |
| " <td>0.090090</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>399</th>\n", | |
| " <td>-0.291111</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.528208</td>\n", | |
| " <td>0.879795</td>\n", | |
| " <td>0.036036</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>424</th>\n", | |
| " <td>-4.987358</td>\n", | |
| " <td>-4.636665</td>\n", | |
| " <td>0.782155</td>\n", | |
| " <td>0.314272</td>\n", | |
| " <td>0.963964</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>476</th>\n", | |
| " <td>1.591957</td>\n", | |
| " <td>0.008964</td>\n", | |
| " <td>0.873217</td>\n", | |
| " <td>0.061256</td>\n", | |
| " <td>-0.135135</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>616</th>\n", | |
| " <td>1.582113</td>\n", | |
| " <td>0.526349</td>\n", | |
| " <td>0.082899</td>\n", | |
| " <td>0.794628</td>\n", | |
| " <td>0.045045</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>706</th>\n", | |
| " <td>0.615461</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.589905</td>\n", | |
| " <td>0.786503</td>\n", | |
| " <td>0.198198</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>942</th>\n", | |
| " <td>-3.316797</td>\n", | |
| " <td>-2.358242</td>\n", | |
| " <td>0.294759</td>\n", | |
| " <td>0.799501</td>\n", | |
| " <td>0.360360</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " true pred freq1 freq2 relief\n", | |
| "279 -0.212210 -0.000000 0.437959 0.912175 0.027027\n", | |
| "359 -1.655751 -0.447206 0.902841 0.360887 0.054054\n", | |
| "389 2.480210 1.641444 0.239744 0.820666 0.090090\n", | |
| "399 -0.291111 0.000000 0.528208 0.879795 0.036036\n", | |
| "424 -4.987358 -4.636665 0.782155 0.314272 0.963964\n", | |
| "476 1.591957 0.008964 0.873217 0.061256 -0.135135\n", | |
| "616 1.582113 0.526349 0.082899 0.794628 0.045045\n", | |
| "706 0.615461 0.000000 0.589905 0.786503 0.198198\n", | |
| "942 -3.316797 -2.358242 0.294759 0.799501 0.360360" | |
| ] | |
| }, | |
| "execution_count": 110, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "effects[effects[\"true\"] != 0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 111, | |
| "id": "raised-charter", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "count 1.000000e+03\n", | |
| "mean 4.630631e-03\n", | |
| "std 1.004398e-01\n", | |
| "min -2.972973e-01\n", | |
| "25% -5.405405e-02\n", | |
| "50% 3.469447e-18\n", | |
| "75% 6.306306e-02\n", | |
| "max 9.639640e-01\n", | |
| "Name: relief, dtype: float64" | |
| ] | |
| }, | |
| "execution_count": 111, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "effects[\"relief\"].describe()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "industrial-christmas", | |
| "metadata": {}, | |
| "source": [ | |
| "[scikit-rebate](https://epistasislab.github.io/scikit-rebate/using/) implements more Relief-based algorithms, but they are really slow:\n", | |
| "it's Python loops all the way down.\n", | |
| "\n", | |
| "There are plenty of other implementations in R." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 112, | |
| "id": "solved-second", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "sim = X.dot(X.T).astype(float)\n", | |
| "sim -= sim.min()\n", | |
| "sim /= 50 #np.percentile(sim, 25)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 113, | |
| "id": "twelve-scheduling", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<AxesSubplot:>" | |
| ] | |
| }, | |
| "execution_count": 113, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 2 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "sns.heatmap(sim)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 114, | |
| "id": "regular-remove", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from itertools import combinations\n", | |
| "\n", | |
| "X_pairs = []\n", | |
| "y_pairs = []\n", | |
| "pair_sim = []\n", | |
| "\n", | |
| "for i, j in combinations(range(X.shape[0]), 2):\n", | |
| " if i == j:\n", | |
| " continue\n", | |
| "\n", | |
| " if y[i] > y[j]:\n", | |
| " y_pairij = 1\n", | |
| " elif np.abs(y[i] - y[j]) < np.std(y):\n", | |
| " y_pairij = 0.5\n", | |
| " else:\n", | |
| " y_pairij = 0\n", | |
| "\n", | |
| " X_pairs.append((X[i], X[j]))\n", | |
| " y_pairs.append(y_pairij)\n", | |
| " pair_sim.append(sim[i, j])\n", | |
| " \n", | |
| "X_pairs = np.array(X_pairs)\n", | |
| "y_pairs = np.array(y_pairs)\n", | |
| "pair_sim = np.array(pair_sim)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 115, | |
| "id": "elect-terrorist", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Reset the global Keras state before the next model is built.\n", | |
| "tf.keras.backend.clear_session()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 116, | |
| "id": "drawn-sapphire", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "tf.Tensor(5.658346, shape=(), dtype=float32)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "optimizer = Adam(learning_rate=1)\n", | |
| "loss_fn = BinaryCrossentropy(from_logits=True)\n", | |
| "\n", | |
| "model = Sequential([Dense(1, activation=\"linear\", use_bias=False, kernel_regularizer=tf.keras.regularizers.l1(1))])\n", | |
| "\n", | |
| "n_epoch = 100\n", | |
| "for _ in range(n_epoch):\n", | |
| " with tf.GradientTape() as tape:\n", | |
| " s1 = model(np.abs(X_pairs[:, 0, :] - X_pairs[:, 1, :]))\n", | |
| " #s2 = model(X_pairs[:, 1, :])\n", | |
| " pred = tfsigmoid(s1)\n", | |
| " loss_value = loss_fn(y_pairs.reshape(-1, 1), pred, sample_weight=pair_sim)\n", | |
| "\n", | |
| " grads = tape.gradient(loss_value, model.trainable_variables)\n", | |
| " optimizer.apply_gradients(zip(grads, model.trainable_variables))\n", | |
| "\n", | |
| "print(loss_value)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 117, | |
| "id": "declared-recycling", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>true</th>\n", | |
| " <th>freq1</th>\n", | |
| " <th>freq2</th>\n", | |
| " <th>relief</th>\n", | |
| " <th>nn</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>869</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.996953</td>\n", | |
| " <td>0.807186</td>\n", | |
| " <td>0.009009</td>\n", | |
| " <td>7.943168</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>484</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.036648</td>\n", | |
| " <td>0.687834</td>\n", | |
| " <td>-0.072072</td>\n", | |
| " <td>4.447771</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>793</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.998612</td>\n", | |
| " <td>0.596046</td>\n", | |
| " <td>0.036036</td>\n", | |
| " <td>3.997329</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>894</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.097156</td>\n", | |
| " <td>0.121705</td>\n", | |
| " <td>0.063063</td>\n", | |
| " <td>3.377757</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>561</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.003345</td>\n", | |
| " <td>0.196184</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>3.357172</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>...</th>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>530</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.959838</td>\n", | |
| " <td>0.461836</td>\n", | |
| " <td>0.063063</td>\n", | |
| " <td>-3.173112</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>675</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.994988</td>\n", | |
| " <td>0.836241</td>\n", | |
| " <td>0.018018</td>\n", | |
| " <td>-3.580165</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>625</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.997831</td>\n", | |
| " <td>0.423959</td>\n", | |
| " <td>0.045045</td>\n", | |
| " <td>-3.708804</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>742</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.994504</td>\n", | |
| " <td>0.051681</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>-4.116137</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>779</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.031603</td>\n", | |
| " <td>0.928924</td>\n", | |
| " <td>0.009009</td>\n", | |
| " <td>-4.695507</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "<p>1000 rows × 5 columns</p>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " true freq1 freq2 relief nn\n", | |
| "869 0.0 0.996953 0.807186 0.009009 7.943168\n", | |
| "484 0.0 0.036648 0.687834 -0.072072 4.447771\n", | |
| "793 0.0 0.998612 0.596046 0.036036 3.997329\n", | |
| "894 0.0 0.097156 0.121705 0.063063 3.377757\n", | |
| "561 0.0 0.003345 0.196184 0.000000 3.357172\n", | |
| ".. ... ... ... ... ...\n", | |
| "530 0.0 0.959838 0.461836 0.063063 -3.173112\n", | |
| "675 0.0 0.994988 0.836241 0.018018 -3.580165\n", | |
| "625 0.0 0.997831 0.423959 0.045045 -3.708804\n", | |
| "742 0.0 0.994504 0.051681 0.000000 -4.116137\n", | |
| "779 0.0 0.031603 0.928924 0.009009 -4.695507\n", | |
| "\n", | |
| "[1000 rows x 5 columns]" | |
| ] | |
| }, | |
| "execution_count": 117, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "effects = pd.DataFrame({\"true\": true_effects, \"freq1\": freqs_pop1, \"freq2\": freqs_pop2, \"relief\": weights, \"nn\": model.get_weights()[0][:, 0]})\n", | |
| "effects.sort_values(\"nn\", ascending=False)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "vital-exchange", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 118, | |
| "id": "north-winning", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>true</th>\n", | |
| " <th>freq1</th>\n", | |
| " <th>freq2</th>\n", | |
| " <th>relief</th>\n", | |
| " <th>nn</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>279</th>\n", | |
| " <td>-0.212210</td>\n", | |
| " <td>0.437959</td>\n", | |
| " <td>0.912175</td>\n", | |
| " <td>0.027027</td>\n", | |
| " <td>-0.100957</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>359</th>\n", | |
| " <td>-1.655751</td>\n", | |
| " <td>0.902841</td>\n", | |
| " <td>0.360887</td>\n", | |
| " <td>0.054054</td>\n", | |
| " <td>-2.756028</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>389</th>\n", | |
| " <td>2.480210</td>\n", | |
| " <td>0.239744</td>\n", | |
| " <td>0.820666</td>\n", | |
| " <td>0.090090</td>\n", | |
| " <td>-0.689894</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>399</th>\n", | |
| " <td>-0.291111</td>\n", | |
| " <td>0.528208</td>\n", | |
| " <td>0.879795</td>\n", | |
| " <td>0.036036</td>\n", | |
| " <td>-0.063009</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>424</th>\n", | |
| " <td>-4.987358</td>\n", | |
| " <td>0.782155</td>\n", | |
| " <td>0.314272</td>\n", | |
| " <td>0.963964</td>\n", | |
| " <td>-2.341457</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>476</th>\n", | |
| " <td>1.591957</td>\n", | |
| " <td>0.873217</td>\n", | |
| " <td>0.061256</td>\n", | |
| " <td>-0.135135</td>\n", | |
| " <td>0.066689</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>616</th>\n", | |
| " <td>1.582113</td>\n", | |
| " <td>0.082899</td>\n", | |
| " <td>0.794628</td>\n", | |
| " <td>0.045045</td>\n", | |
| " <td>-2.427979</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>706</th>\n", | |
| " <td>0.615461</td>\n", | |
| " <td>0.589905</td>\n", | |
| " <td>0.786503</td>\n", | |
| " <td>0.198198</td>\n", | |
| " <td>-0.165121</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>942</th>\n", | |
| " <td>-3.316797</td>\n", | |
| " <td>0.294759</td>\n", | |
| " <td>0.799501</td>\n", | |
| " <td>0.360360</td>\n", | |
| " <td>1.856650</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " true freq1 freq2 relief nn\n", | |
| "279 -0.212210 0.437959 0.912175 0.027027 -0.100957\n", | |
| "359 -1.655751 0.902841 0.360887 0.054054 -2.756028\n", | |
| "389 2.480210 0.239744 0.820666 0.090090 -0.689894\n", | |
| "399 -0.291111 0.528208 0.879795 0.036036 -0.063009\n", | |
| "424 -4.987358 0.782155 0.314272 0.963964 -2.341457\n", | |
| "476 1.591957 0.873217 0.061256 -0.135135 0.066689\n", | |
| "616 1.582113 0.082899 0.794628 0.045045 -2.427979\n", | |
| "706 0.615461 0.589905 0.786503 0.198198 -0.165121\n", | |
| "942 -3.316797 0.294759 0.799501 0.360360 1.856650" | |
| ] | |
| }, | |
| "execution_count": 118, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "effects[effects[\"true\"] != 0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 119, | |
| "id": "extra-composite", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "count 1000.000000\n", | |
| "mean -0.027586\n", | |
| "std 0.865034\n", | |
| "min -4.695507\n", | |
| "25% -0.400552\n", | |
| "50% -0.027578\n", | |
| "75% 0.361622\n", | |
| "max 7.943168\n", | |
| "Name: nn, dtype: float64" | |
| ] | |
| }, | |
| "execution_count": 119, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "effects[\"nn\"].describe()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 120, | |
| "id": "simple-transition", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<AxesSubplot:xlabel='nn', ylabel='Density'>" | |
| ] | |
| }, | |
| "execution_count": 120, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "sns.kdeplot(effects[\"nn\"])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "bottom-structure", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python [conda env:.conda-condaenv]", | |
| "language": "python", | |
| "name": "conda-env-.conda-condaenv-py" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.7.7" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment