Last active
November 25, 2022 06:15
-
-
Save ljmartin/d52eb45cd07d657c4469a83911a6590a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "2c411376-8932-49d0-bcfa-ac85428db470", | |
"metadata": {}, | |
"source": [ | |
"# Logistic regression on Morgan fingerprints using duckdb\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "4f7259df-1817-4030-8a5d-5dbe96211213", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from rdkit import Chem\n", | |
"from rdkit.Chem import AllChem, rdMolDescriptors\n", | |
"\n", | |
"import duckdb\n", | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"from scipy.special import expit\n", | |
"\n", | |
"from sklearn.linear_model import LogisticRegression\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "70eb937d-e4ad-4854-863c-c4e944a595eb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#load a sample of enamine molecules\n", | |
"\n", | |
"df = pd.read_csv('./456.smi', sep='\\t',\n", | |
" names=['smiles', 'idnumber'])\n", | |
"\n", | |
"df['mols'] = [Chem.MolFromSmiles(i) for i in df['smiles']]\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "e617eee1-6c8a-4be0-9aab-20f123a51dc9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#add a column to the df that is simply a list of the On bits from each fingerprint. \n", | |
"\n", | |
"fp_bits = []\n", | |
"\n", | |
"for mol in df['mols']:\n", | |
" bits = list(AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048).GetOnBits())\n", | |
" fp_bits.append(bits)\n", | |
"df['fps'] = fp_bits\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "b3f9d924-ec52-41e7-96d0-18b532e45939", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#in this simple example, we will regress against cLogP > 3\n", | |
"df['clogp'] = [rdMolDescriptors.CalcCrippenDescriptors(m)[0] for m in df['mols']]\n", | |
"\n", | |
"df['label'] = df['clogp'] > 3" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "8e7be7e5-7b9f-46c1-8136-897f0cce0c01", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>smiles</th>\n", | |
" <th>idnumber</th>\n", | |
" <th>mols</th>\n", | |
" <th>fps</th>\n", | |
" <th>clogp</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>CCOC(C)C1=NC(CN(C)C(C)C2=CC=CC(OC)=C2)=CS1</td>\n", | |
" <td>Z1233971164</td>\n", | |
" <td><rdkit.Chem.rdchem.Mol object at 0x18228f6a0></td>\n", | |
" <td>[1, 31, 52, 80, 125, 283, 294, 315, 316, 317, ...</td>\n", | |
" <td>4.44220</td>\n", | |
" <td>True</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>CC1CN(C(=O)C2CCCC2)CCN1C(=O)C1=C(F)C(N)=CC=C1</td>\n", | |
" <td>PV-005683708030</td>\n", | |
" <td><rdkit.Chem.rdchem.Mol object at 0x18228f640></td>\n", | |
" <td>[47, 53, 127, 242, 306, 454, 494, 650, 699, 73...</td>\n", | |
" <td>2.27100</td>\n", | |
" <td>False</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>CC1(C)CN(C(=O)CCC2CC2)CC1NC(=O)C12COC(CF)(C1)C2</td>\n", | |
" <td>PV-003367388480</td>\n", | |
" <td><rdkit.Chem.rdchem.Mol object at 0x18228f820></td>\n", | |
" <td>[29, 80, 114, 174, 246, 310, 351, 387, 432, 47...</td>\n", | |
" <td>2.04850</td>\n", | |
" <td>False</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>COC1(CC(=O)N[C@H]2COCC[C@H]2NC(=O)C2CCSCC2)CCC...</td>\n", | |
" <td>PV-006682455241</td>\n", | |
" <td><rdkit.Chem.rdchem.Mol object at 0x18228f940></td>\n", | |
" <td>[41, 47, 80, 83, 233, 292, 387, 422, 523, 650,...</td>\n", | |
" <td>1.86880</td>\n", | |
" <td>False</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>CC1=CC=C(CC2=NOC(CCNC3=CC=CC=C3[N+](=O)[O-])=N...</td>\n", | |
" <td>Z1683337578</td>\n", | |
" <td><rdkit.Chem.rdchem.Mol object at 0x18228fca0></td>\n", | |
" <td>[80, 91, 112, 164, 227, 235, 248, 310, 354, 37...</td>\n", | |
" <td>3.53162</td>\n", | |
" <td>True</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" smiles idnumber \\\n", | |
"0 CCOC(C)C1=NC(CN(C)C(C)C2=CC=CC(OC)=C2)=CS1 Z1233971164 \n", | |
"1 CC1CN(C(=O)C2CCCC2)CCN1C(=O)C1=C(F)C(N)=CC=C1 PV-005683708030 \n", | |
"2 CC1(C)CN(C(=O)CCC2CC2)CC1NC(=O)C12COC(CF)(C1)C2 PV-003367388480 \n", | |
"3 COC1(CC(=O)N[C@H]2COCC[C@H]2NC(=O)C2CCSCC2)CCC... PV-006682455241 \n", | |
"4 CC1=CC=C(CC2=NOC(CCNC3=CC=CC=C3[N+](=O)[O-])=N... Z1683337578 \n", | |
"\n", | |
" mols \\\n", | |
"0 <rdkit.Chem.rdchem.Mol object at 0x18228f6a0> \n", | |
"1 <rdkit.Chem.rdchem.Mol object at 0x18228f640> \n", | |
"2 <rdkit.Chem.rdchem.Mol object at 0x18228f820> \n", | |
"3 <rdkit.Chem.rdchem.Mol object at 0x18228f940> \n", | |
"4 <rdkit.Chem.rdchem.Mol object at 0x18228fca0> \n", | |
"\n", | |
" fps clogp label \n", | |
"0 [1, 31, 52, 80, 125, 283, 294, 315, 316, 317, ... 4.44220 True \n", | |
"1 [47, 53, 127, 242, 306, 454, 494, 650, 699, 73... 2.27100 False \n", | |
"2 [29, 80, 114, 174, 246, 310, 351, 387, 432, 47... 2.04850 False \n", | |
"3 [41, 47, 80, 83, 233, 292, 387, 422, 523, 650,... 1.86880 False \n", | |
"4 [80, 91, 112, 164, 227, 235, 248, 310, 354, 37... 3.53162 True " | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"#take a look at what we have so far\n", | |
"df.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "e84d5991-cee8-4886-8512-a346e83e7998", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# turn the On bits into a dense matrix for ingestion with sklearn\n", | |
"\n", | |
"df['n'] = df['fps'].str.len()\n", | |
"\n", | |
"xidx = np.repeat( \n", | |
" np.arange(len(df)), \n", | |
" df['n']\n", | |
")\n", | |
"yidx = np.concatenate(df['fps'].values)\n", | |
"\n", | |
"fps = np.zeros([\n", | |
" len(df), \n", | |
" 2048\n", | |
"]).astype(int)\n", | |
"\n", | |
"fps[xidx, yidx] = 1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "4df38dff-74c8-4009-ba81-42412c8a20cf", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"LogisticRegression(max_iter=10000)" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"#fit the sklearn model. We will use this to verify the duckdb implementation is equivalent \n", | |
"model = LogisticRegression(max_iter = 10_000)\n", | |
"model.fit(\n", | |
" fps, \n", | |
" df['label']\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c343bc15-5f3f-431d-a4f6-71fbb1bec3df", | |
"metadata": {}, | |
"source": [ | |
"# duckdb logreg\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "bbd271e7-4117-4be4-b248-7f32c367ca88", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<duckdb.DuckDBPyConnection at 0x1824f5070>" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"cursor = duckdb.connect()\n", | |
"\n", | |
"#create a table with one data column, where each row can be a LIST of doubles\n", | |
"cursor.execute(\"\"\"CREATE TABLE logreg(coefs DOUBLE[]);\"\"\")\n", | |
"\n", | |
"#into this table, insert a single row: the array of fitted coefficients from the sklearn model \n", | |
"#(it gets converted to duckdb list in the process)\n", | |
"coef_df = pd.DataFrame.from_dict({\"coefs\":[model.coef_[0]]})\n", | |
"cursor.execute('INSERT INTO logreg VALUES ((select * from coef_df));')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "cffae1f8-6765-4e3e-aa45-8ea5ca7e3edf", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1 31 52 80 125 283 294 315 316 317 322 378 507 552 562 577 583 675 695 724 781 841 875 881 1039 1052 1057 1088 1091 1163 1238 1241 1365 1380 1536 1567 1581 1595 1656 1682 1750 1819 1855 1873 1876 1972 1988\t" | |
] | |
} | |
], | |
"source": [ | |
"#separately, create another new table that has lists of fingerprint bits as all of it's rows:\n", | |
"cursor.execute(\"CREATE TABLE df AS SELECT fps FROM df\")\n", | |
"\n", | |
"#example of what this looks like (take the first row and print it):\n", | |
"print(*(cursor.execute(\"SELECT fps FROM df LIMIT 1\").fetchall()[0][0]), end='\\t')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "ff192a13-f54a-4b80-a954-151bcc08fa61", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1 31 52 80 125 283 294 315 316 317 322 378 507 552 562 577 583 675 695 724 781 841 875 881 1039 1052 1057 1088 1091 1163 1238 1241 1365 1380 1536 1567 1581 1595 1656 1682 1750 1819 1855 1873 1876 1972 1988\t" | |
] | |
} | |
], | |
"source": [ | |
"#verify by eye that these bits are the same:\n", | |
"print(*df.iloc[0]['fps'], end='\\t')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "da0f18b5-320e-446e-9d6b-af9fe85c2df6", | |
"metadata": { | |
"tags": [] | |
}, | |
"source": [ | |
"# crucial step - implementing `predict_proba`:\n", | |
"\n", | |
"this isn't pretty, but it's just doing the logistic regression prediction equation:\n", | |
"\n", | |
"```python\n", | |
"from scipy.special import expit\n", | |
"expit( (model.coef_[0]*fps).sum(1) + model.intercept_[0] )\n", | |
"```\n", | |
"\n", | |
"Why is it so hacky-looking? Duckdb doesn't have expit. But expit is just:\n", | |
"\n", | |
"```python\n", | |
"1 / ( 1+np.exp(-x) )\n", | |
"```\n", | |
"\n", | |
"well, it looked like duckdb doesn't have `exp` either (it does in fact provide an `exp(x)` function, which would be a cleaner choice here). `np.exp(x)` is just:\n", | |
"```python\n", | |
"e**x\n", | |
"```\n", | |
"\n", | |
"and duckdb doesn't have **e** either!! Which is why this looks so messy. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "0d72ceb2-f788-48e6-9906-b0da2720d715", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#predict on all the fingerprints, using the On bits\n", | |
"result = cursor.execute(f\"\"\"select 1/(1+ pow(2.71828, -list_aggregate([coefs[i+1] for i in fps], 'sum') - {model.intercept_[0]}))\n", | |
" from logreg, df;\"\"\")\n", | |
"proba_duckdb = result.fetchall()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "b7d03749-0bd0-441c-b8df-8239956a7cdf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#predict using the equation written into the duckdb (as a sanity check)\n", | |
"proba_manual = expit((model.coef_[0]*fps).sum(1)+model.intercept_[0])\n", | |
"#predict using sklearn's method\n", | |
"proba_sklearn = model.predict_proba(fps)[:,1]\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "52eee0d9-4f50-41c3-8a22-c3d99d933fbf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#are they all the same? Yes!\n", | |
"assert np.allclose(proba_manual, proba_sklearn, proba_duckdb), \"It didn't work\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "ad9408bf-d837-4858-83ed-8ce42340e757", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment