Created
March 7, 2023 12:32
-
-
Save MaxHalford/823c4e7f9216607dc853724ec74ec692 to your computer and use it in GitHub Desktop.
Online gradient descent written in SQL
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"attachments": {}, | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Online gradient descent (OGD) in SQL" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 188, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[*********************100%***********************] 1 of 1 completed\n", | |
"| Date | Open | High | Low | Close | Adj Close | Volume |\n", | |
"|:--------------------|--------:|--------:|--------:|--------:|------------:|---------:|\n", | |
"| 2021-12-27 00:00:00 | 6.01218 | 6.09808 | 6.04943 | 6.1254 | 6.07454 | 1.18628 |\n", | |
"| 2021-12-28 00:00:00 | 6.1164 | 6.12883 | 6.09931 | 6.09008 | 6.03951 | 1.25318 |\n", | |
"| 2021-12-29 00:00:00 | 6.08822 | 6.10517 | 6.08598 | 6.09314 | 6.04254 | 0.987236 |\n", | |
"| 2021-12-30 00:00:00 | 6.09298 | 6.10315 | 6.08427 | 6.05305 | 6.00279 | 0.946449 |\n", | |
"| 2021-12-31 00:00:00 | 6.04613 | 6.05785 | 6.05592 | 6.03165 | 5.98157 | 1.01437 |\n" | |
] | |
} | |
], | |
"source": [ | |
"import yfinance as yf\n", | |
"\n", | |
"figures = yf.download(\n", | |
" tickers=['AAPL'],\n", | |
" start='2020-01-01',\n", | |
" end='2022-01-01'\n", | |
")\n", | |
"figures = figures / figures.std()\n", | |
"print(figures.tail().to_markdown())" | |
] | |
}, | |
{ | |
"attachments": {}, | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Mean" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 200, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"┌───────┬───────────────────┬────────────────────┐\n", | |
"│ step │ x │ avg │\n", | |
"│ int64 │ double │ double │\n", | |
"├───────┼───────────────────┼────────────────────┤\n", | |
"│ 505 │ 5.981568542028378 │ 3.9577706471349923 │\n", | |
"│ 504 │ 6.002789566151079 │ 3.953755175121315 │\n", | |
"│ 503 │ 6.042539700173864 │ 3.949681548101375 │\n", | |
"│ 502 │ 6.039508125299193 │ 3.945512507957804 │\n", | |
"│ 501 │ 6.074541325571636 │ 3.941332875987063 │\n", | |
"└───────┴───────────────────┴────────────────────┘\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"import duckdb\n", | |
"\n", | |
"# Running mean with a recursive CTE, seeded with avg_1 = x_1 and updated as\n", | |
"# avg_n = avg_{n-1} + (x_n - avg_{n-1}) / n.\n", | |
"# NOTE(review): ROW_NUMBER() OVER () has no ORDER BY, so this assumes DuckDB scans\n", | |
"# `figures` in row (date) order; the ORDER BY inside the CTE does not enforce it.\n", | |
"duckdb.sql('''\n", | |
"WITH RECURSIVE\n", | |
" stream AS (\n", | |
" SELECT ROW_NUMBER() OVER () AS step, \"Adj Close\" AS x\n", | |
" FROM figures\n", | |
" ORDER BY step\n", | |
" ),\n", | |
" state(step, x, avg) AS (\n", | |
" -- Initialize\n", | |
" SELECT step, x, x AS avg\n", | |
" FROM stream\n", | |
" WHERE step = 1\n", | |
" UNION ALL\n", | |
" -- Update\n", | |
" SELECT\n", | |
" stream.step,\n", | |
" stream.x,\n", | |
" state.avg + (stream.x - state.avg) / stream.step AS avg\n", | |
" FROM stream\n", | |
" INNER JOIN state ON state.step + 1 = stream.step\n", | |
" )\n", | |
"\n", | |
"SELECT *\n", | |
"FROM state\n", | |
"ORDER BY step DESC\n", | |
"LIMIT 5\n", | |
"''').show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 206, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Date\n", | |
"2021-12-31 3.957771\n", | |
"2021-12-30 3.953755\n", | |
"2021-12-29 3.949682\n", | |
"2021-12-28 3.945513\n", | |
"2021-12-27 3.941333\n", | |
"Name: Adj Close, dtype: float64" | |
] | |
}, | |
"execution_count": 206, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"figures['Adj Close'].rolling(len(figures), min_periods=1).mean().tail()[::-1]" | |
] | |
}, | |
{ | |
"attachments": {}, | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Covariance" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 192, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"┌───────┬────────────────────┐\n", | |
"│ step │ cov │\n", | |
"│ int64 │ double │\n", | |
"├───────┼────────────────────┤\n", | |
"│ 505 │ 0.9979967767965502 │\n", | |
"│ 504 │ 0.9918524780369538 │\n", | |
"│ 503 │ 0.985478504290919 │\n", | |
"│ 502 │ 0.9787158318485241 │\n", | |
"│ 501 │ 0.9719167545245742 │\n", | |
"└───────┴────────────────────┘\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"duckdb.sql('''\n", | |
"WITH RECURSIVE\n", | |
" stream AS (\n", | |
" SELECT\n", | |
" ROW_NUMBER() OVER () AS step,\n", | |
" \"Adj Close\" AS x,\n", | |
" \"Close\" AS y\n", | |
" FROM figures\n", | |
" ),\n", | |
" state(step, x, x_avg, y, y_avg, cov) AS (\n", | |
" -- Initialize\n", | |
" SELECT\n", | |
" step,\n", | |
" x,\n", | |
" x AS x_avg,\n", | |
" y,\n", | |
" y AS y_avg,\n", | |
" 0::DOUBLE AS cov\n", | |
" FROM stream\n", | |
" WHERE step = 1\n", | |
" UNION ALL\n", | |
" -- Update\n", | |
" SELECT\n", | |
" step,\n", | |
" x,\n", | |
" x_new_avg AS x_avg,\n", | |
" y,\n", | |
" y_new_avg AS y_avg,\n", | |
" cov + ((x - x_prev_avg) * (y - y_new_avg) - cov) / step AS cov\n", | |
" FROM (\n", | |
" SELECT\n", | |
" stream.step,\n", | |
" stream.x,\n", | |
" stream.y,\n", | |
" state.x_avg AS x_prev_avg,\n", | |
" state.x_avg + (stream.x - state.x_avg) / stream.step AS x_new_avg,\n", | |
" state.y_avg AS y_prev_avg,\n", | |
" state.y_avg + (stream.y - state.y_avg) / stream.step AS y_new_avg,\n", | |
" state.cov\n", | |
" FROM stream\n", | |
" INNER JOIN state ON state.step + 1 = stream.step\n", | |
" )\n", | |
" )\n", | |
"\n", | |
"SELECT step, cov\n", | |
"FROM state\n", | |
"ORDER BY step DESC\n", | |
"LIMIT 5\n", | |
"''').show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 217, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Date\n", | |
"2021-12-31 0.997997\n", | |
"2021-12-30 0.991852\n", | |
"2021-12-29 0.985479\n", | |
"2021-12-28 0.978716\n", | |
"2021-12-27 0.971917\n", | |
"Name: Adj Close, dtype: float64" | |
] | |
}, | |
"execution_count": 217, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"(\n", | |
" figures\n", | |
" .rolling(len(figures), min_periods=1)\n", | |
" .cov(ddof=0)['Adj Close']\n", | |
" .loc[:, 'Close']\n", | |
" .tail()[::-1]\n", | |
")" | |
] | |
}, | |
{ | |
"attachments": {}, | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Handling many variables" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 221, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"| date | variable | value |\n", | |
"|:--------------------|:-----------|--------:|\n", | |
"| 2020-01-02 00:00:00 | Adj Close | 2.49235 |\n", | |
"| 2020-01-02 00:00:00 | Close | 2.55055 |\n", | |
"| 2020-01-02 00:00:00 | High | 2.54002 |\n", | |
"| 2020-01-02 00:00:00 | Low | 2.52122 |\n", | |
"| 2020-01-02 00:00:00 | Open | 2.51432 |\n", | |
"| 2020-01-02 00:00:00 | Volume | 2.14521 |\n", | |
"| 2020-01-03 00:00:00 | Adj Close | 2.46812 |\n", | |
"| 2020-01-03 00:00:00 | Close | 2.52576 |\n", | |
"| 2020-01-03 00:00:00 | High | 2.53985 |\n", | |
"| 2020-01-03 00:00:00 | Low | 2.53241 |\n" | |
] | |
} | |
], | |
"source": [ | |
"figures_flat = figures.melt(ignore_index=False).reset_index()\n", | |
"figures_flat.columns = ['date', 'variable', 'value']\n", | |
"figures_flat = figures_flat.sort_values(['date', 'variable'])\n", | |
"print(figures_flat.head(10).to_markdown(index=False))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 222, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"┌───────┬───────────┬────────────────────┬────────────────────┐\n", | |
"│ step │ variable │ value │ avg │\n", | |
"│ int64 │ varchar │ double │ double │\n", | |
"├───────┼───────────┼────────────────────┼────────────────────┤\n", | |
"│ 505 │ Adj Close │ 5.981568542028378 │ 3.9577706471349923 │\n", | |
"│ 505 │ Close │ 6.03165394229666 │ 4.012373756823449 │\n", | |
"│ 505 │ High │ 6.057853942108038 │ 4.03765319364954 │\n", | |
"│ 505 │ Low │ 6.05591789308585 │ 3.985178489614261 │\n", | |
"│ 505 │ Open │ 6.046125216781687 │ 4.006746251814558 │\n", | |
"│ 505 │ Volume │ 1.0143664144585565 │ 1.9651814487272024 │\n", | |
"└───────┴───────────┴────────────────────┴────────────────────┘\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"duckdb.sql('''\n", | |
"WITH RECURSIVE\n", | |
" stream AS (\n", | |
" SELECT RANK_DENSE() OVER (ORDER BY date) AS step, *\n", | |
" FROM figures_flat\n", | |
" ORDER BY date\n", | |
" ),\n", | |
" state(step, variable, value, avg) AS (\n", | |
" -- Initialize\n", | |
" SELECT step, variable, value, value AS avg\n", | |
" FROM stream\n", | |
" WHERE step = 1\n", | |
" UNION ALL\n", | |
" -- Update\n", | |
" SELECT\n", | |
" stream.step,\n", | |
" stream.variable,\n", | |
" stream.value,\n", | |
" state.avg + (stream.value - state.avg) / stream.step AS avg\n", | |
" FROM stream\n", | |
" INNER JOIN state ON\n", | |
" state.step + 1 = stream.step AND\n", | |
" state.variable = stream.variable\n", | |
" )\n", | |
"\n", | |
"SELECT *\n", | |
"FROM state\n", | |
"WHERE step = (SELECT MAX(step) FROM state)\n", | |
"ORDER BY variable\n", | |
"''').show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 232, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"variable \n", | |
"Adj Close 2524 3.957771\n", | |
"Close 2019 4.012374\n", | |
"High 1009 4.037653\n", | |
"Low 1514 3.985178\n", | |
"Open 504 4.006746\n", | |
"Volume 3029 1.965181\n", | |
"Name: value, dtype: float64" | |
] | |
}, | |
"execution_count": 232, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"(\n", | |
" figures_flat\n", | |
" .groupby('variable')['value']\n", | |
" .rolling(len(figures_flat), min_periods=1)\n", | |
" .mean()\n", | |
" .groupby('variable')\n", | |
" .tail(1)[::-1].sort_index()\n", | |
")" | |
] | |
}, | |
{ | |
"attachments": {}, | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Stochastic gradient descent\n", | |
"\n", | |
"Vanilla SGD, meaning:\n", | |
"\n", | |
"- Constant learning rate\n", | |
"- Single epoch\n", | |
"- Squared loss\n", | |
"- Weights initialised to zero\n", | |
"- No gradient clipping\n", | |
"- No regularisation\n", | |
"- No intercept" | |
] | |
}, | |
{ | |
"attachments": {}, | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"$$p_t = \\dot{w}_t \\cdot \\dot{x}_t = \\sum_{i} w_{t,i} \\, x_{t,i}$$\n", | |
"$$l_t = \\frac{\\partial}{\\partial p_t} \\tfrac{1}{2} (p_t - y_t)^2 = p_t - y_t$$\n", | |
"$$\\dot{g}_t = l_t \\, \\dot{x}_t$$\n", | |
"$$\\dot{w}_{t+1} = \\dot{w}_t - \\eta \\, \\dot{g}_t$$" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 234, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"┌───────┬──────────┬──────────────────────┬───────────────────┬───────────────────┐\n", | |
"│ step │ variable │ weight │ target │ prediction │\n", | |
"│ int64 │ varchar │ double │ double │ double │\n", | |
"├───────┼──────────┼──────────────────────┼───────────────────┼───────────────────┤\n", | |
"│ 505 │ Close │ 0.2511547716803354 │ 5.981568542028378 │ 5.938875441702928 │\n", | |
"│ 505 │ High │ 0.24043897039853313 │ 5.981568542028378 │ 5.938875441702928 │\n", | |
"│ 505 │ Low │ 0.2447191283620627 │ 5.981568542028378 │ 5.938875441702928 │\n", | |
"│ 505 │ Open │ 0.23603830762609726 │ 5.981568542028378 │ 5.938875441702928 │\n", | |
"│ 505 │ Volume │ 0.057510279698874206 │ 5.981568542028378 │ 5.938875441702928 │\n", | |
"└───────┴──────────┴──────────────────────┴───────────────────┴───────────────────┘" | |
] | |
}, | |
"execution_count": 234, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"duckdb.sql('''\n", | |
"WITH RECURSIVE\n", | |
" X AS (\n", | |
" SELECT\n", | |
" RANK_DENSE() OVER (ORDER BY date) AS step, *\n", | |
" FROM figures_flat\n", | |
" WHERE variable != 'Adj Close'\n", | |
" ORDER BY date\n", | |
" ),\n", | |
" y AS (\n", | |
" SELECT\n", | |
" RANK_DENSE() OVER (ORDER BY date) AS step, *\n", | |
" FROM figures_flat\n", | |
" WHERE variable = 'Adj Close'\n", | |
" ORDER BY date\n", | |
" ),\n", | |
" stream AS (\n", | |
" SELECT X.*, y.value AS target\n", | |
" FROM X\n", | |
" INNER JOIN y ON X.step = y.step\n", | |
" ),\n", | |
" state AS (\n", | |
" -- Initialize\n", | |
" SELECT\n", | |
" step,\n", | |
" target,\n", | |
" variable,\n", | |
" value,\n", | |
" 0::DOUBLE AS weight,\n", | |
" 0::DOUBLE AS prediction\n", | |
" FROM stream\n", | |
" WHERE step = 1\n", | |
" UNION ALL\n", | |
" -- Update\n", | |
" SELECT\n", | |
" step,\n", | |
" target,\n", | |
" variable,\n", | |
" value,\n", | |
" weight,\n", | |
" SUM(weight * value) OVER () AS prediction\n", | |
" FROM (\n", | |
" SELECT\n", | |
" stream.step,\n", | |
" stream.target,\n", | |
" stream.variable,\n", | |
" stream.value,\n", | |
" state.prediction - state.target AS loss_gradient,\n", | |
" loss_gradient * state.value AS gradient,\n", | |
" state.weight - 0.01 * gradient AS weight\n", | |
" FROM stream\n", | |
" INNER JOIN state ON\n", | |
" state.step + 1 = stream.step AND\n", | |
" state.variable = stream.variable\n", | |
" )\n", | |
" )\n", | |
"\n", | |
"SELECT step, variable, weight, target, prediction\n", | |
"FROM state\n", | |
"WHERE step = (SELECT MAX(step) FROM state)\n", | |
"ORDER BY variable\n", | |
"''')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 236, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'Close': 0.2511547716803354,\n", | |
" 'High': 0.2404389703985331,\n", | |
" 'Low': 0.2447191283620624,\n", | |
" 'Open': 0.23603830762609757,\n", | |
" 'Volume': 0.05751027969887417}\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/max/.pyenv/versions/3.11.0/lib/python3.11/site-packages/sklearn/linear_model/_stochastic_gradient.py:1551: ConvergenceWarning: Maximum number of iteration reached before convergence. Consider increasing max_iter to improve the fit.\n", | |
" warnings.warn(\n" | |
] | |
} | |
], | |
"source": [ | |
"from pprint import pprint\n", | |
"from sklearn import linear_model\n", | |
"\n", | |
"model = linear_model.SGDRegressor(\n", | |
" loss='squared_error',\n", | |
" penalty=None,\n", | |
" fit_intercept=False,\n", | |
" learning_rate='constant',\n", | |
" eta0=0.01,\n", | |
" max_iter=1,\n", | |
" shuffle=False\n", | |
")\n", | |
"\n", | |
"X = figures[:-1].copy()\n", | |
"y = X.pop('Adj Close')\n", | |
"model = model.fit(X, y)\n", | |
"pprint(dict(zip(X.columns, model.coef_)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 237, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'Close': 0.2511547716803356,\n", | |
" 'High': 0.2404389703985331,\n", | |
" 'Low': 0.24471912836206253,\n", | |
" 'Open': 0.2360383076260972,\n", | |
" 'Volume': 0.057510279698874255}\n" | |
] | |
} | |
], | |
"source": [ | |
"from river import linear_model\n", | |
"from river import optim\n", | |
"\n", | |
"class ScikitLearnSquaredLoss:\n", | |
" \"\"\"sklearn removes the leading 2 from the gradient of the squared loss.\"\"\"\n", | |
"\n", | |
" def gradient(self, y_true, y_pred):\n", | |
" return y_pred - y_true\n", | |
"\n", | |
"model = linear_model.LinearRegression(\n", | |
" optimizer=optim.SGD(lr=0.01),\n", | |
" loss=ScikitLearnSquaredLoss(),\n", | |
" intercept_lr=0.0,\n", | |
" l2=0.0\n", | |
")\n", | |
"\n", | |
"for i, x in enumerate(figures[:-1].to_dict(orient='records')):\n", | |
" y = x.pop('Adj Close')\n", | |
" model.learn_one(x, y)\n", | |
"\n", | |
"pprint(model.weights)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.0" | |
}, | |
"orig_nbformat": 4 | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment