Created
April 12, 2020 04:07
-
-
Save buswedg/3d52b04a7d2d871cc56bf0850866944a to your computer and use it in GitHub Desktop.
reinforcement_learning_for_share_trading\strategy_learner
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import seaborn as sns\n", | |
| "sns.set_style('darkgrid')\n", | |
| "\n", | |
| "import datetime as dt\n", | |
| "\n", | |
| "import pandas as pd\n", | |
| "\n", | |
| "from q_learner import *\n", | |
| "from indicators import *\n", | |
| "from col_refs import *" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
class StrategyLearner(object):
    """Q-learning trading strategy over discretised technical-indicator states.

    Each row's state is built from the columns in ``STATE_FEATURES``: every
    feature is bucketed into ``NUM_STEPS`` bins and the resulting bin numbers
    are combined into one integer, one decimal digit per feature.  The same
    encoding is used by ``addEvidence`` (training) and ``testPolicy``
    (evaluation) so that the Q-table is always indexed consistently.

    The learner moves between three positions: ``-trade_size``, ``0`` and
    ``+trade_size`` shares.
    """

    # Feature columns (in digit order) that make up the state.
    STATE_FEATURES = ('momentum_roc', 'momentum_rsi', 'trend_macd_diff', 'volatility_bbm')

    # Bins per feature.  Must stay <= 10 so each feature contributes exactly
    # one decimal digit; with 4 features and digits 0..8 the largest state is
    # 8888, which fits in the ``num_states=9999`` passed to QLearner below.
    NUM_STEPS = 9

    # Training-loop limits used by ``addEvidence``.
    MIN_EPOCH = 20   # epochs that must elapse before repeats are counted
    COV_EPOCH = 40   # number of counted repeats after which training stops
    MAX_EPOCH = 500  # hard cap on the number of epochs

    def __init__(self, impact=0.0, verbose=False):
        """Store the settings and create the underlying Q-learner.

        Args:
            impact: fraction removed from every reward via ``(1 - impact)``.
            verbose: stored on the instance; not read by this class.
        """
        self.impact = impact
        self.verbose = verbose

        self.ql = QLearner(num_states=9999, num_actions=3, alpha=0.2, gamma=0.9, rar=0.98, radr=0.999, dyna=0,
                           verbose=False)

    def get_bins(self, ps_feature, num_steps):
        """Return ``num_steps + 1`` ascending bin edges for a feature series.

        Edges are taken from the sorted non-NaN values: the minimum, then
        every ``step_size``-th value, and finally the maximum.

        NaN values are dropped first.  Left in, they sort to the end and the
        last edge would itself be NaN.

        NOTE(review): ``step_size`` divides by ``num_steps + 1`` (the number
        of edges), so the final bin spans roughly two steps.  Kept as-is to
        preserve the existing binning.

        Args:
            ps_feature: pandas Series of feature values.
            num_steps: number of bins to produce edges for.

        Returns:
            List of ``num_steps + 1`` values taken from ``ps_feature``.

        Raises:
            IndexError: if ``ps_feature`` has no non-NaN values.
        """
        ordered = ps_feature.dropna().sort_values()
        step_size = len(ordered.index) // (num_steps + 1)

        inner_edges = [ordered.iloc[s * step_size] for s in range(1, num_steps)]
        return [ordered.iloc[0]] + inner_edges + [ordered.iloc[-1]]

    def get_trade(self, action, holdings, trade_size):
        """Translate a learner action into an order and the new position.

        Action ``1`` buys and action ``2`` sells; anything else holds.  A
        buy/sell is only issued when the position is not already fully
        long/short.  From a flat position ``trade_size`` shares are traded;
        otherwise ``2 * trade_size`` shares are traded to flip the position.

        Args:
            action: integer action returned by the Q-learner.
            holdings: current position in shares.
            trade_size: size of one full position in shares.

        Returns:
            Tuple ``(order, shares, holdings)`` where ``order`` is ``'BUY'``,
            ``'SELL'`` or ``'HOLD'``, ``shares`` is the traded quantity and
            ``holdings`` is the position after the trade.
        """
        if action == 1 and holdings < trade_size:
            shares = trade_size if holdings == 0 else 2 * trade_size
            return 'BUY', shares, holdings + shares

        if action == 2 and holdings > -trade_size:
            shares = trade_size if holdings == 0 else 2 * trade_size
            return 'SELL', shares, holdings - shares

        return 'HOLD', 0, holdings

    def _build_states(self, pd_features):
        """Return an integer state per row of ``pd_features``.

        Shared by training and evaluation so both use the same feature
        columns and the same number of bins.
        """
        state = pd.Series(0, index=pd_features.index, dtype=int)

        for column in self.STATE_FEATURES:
            values = pd_features[column]
            edges = self.get_bins(values, self.NUM_STEPS)
            # pd.cut leaves the minimum value (and any NaN) unbinned; those
            # rows are assigned to bin 0.
            digit = pd.cut(values, bins=edges, labels=list(range(self.NUM_STEPS))).fillna(0).astype(int)
            # Append this feature's bin as the next decimal digit.
            state = state * 10 + digit

        return state

    @staticmethod
    def _daily_returns(pd_prices):
        """Return the row-over-row relative change of ``adj_close`` (from row 1 on)."""
        prices = pd_prices['adj_close']
        return (prices.iloc[1:] / prices.iloc[:-1].values) - 1

    @staticmethod
    def _trades_frame(symbol, index, orders, shares):
        """Assemble the ``entity_symbol`` / ``order`` / ``shares`` trades frame."""
        frame = pd.DataFrame({'entity_symbol': symbol, 'order': orders, 'shares': shares}, index=index)
        # Share counts are fractional (sv / price), so keep the column float
        # even when a pass produced only HOLD rows.
        frame['shares'] = frame['shares'].astype(float)
        return frame

    def _run_pass(self, symbol, pd_prices, states, returns, trade_size):
        """Walk the price index once, querying the learner, and return all rows' trades.

        The first row is always ``HOLD`` with 0 shares; every later row gets
        the order produced by ``get_trade`` for the learner's action.
        """
        dates = pd_prices.index
        keep = 1 - self.impact

        orders = ['HOLD']
        shares_traded = [0]
        holdings = 0

        for date in dates[1:]:
            reward = holdings * returns.loc[date] * keep
            action = self.ql.query(states.loc[date], reward)

            order, shares, holdings = self.get_trade(action, holdings, trade_size)
            orders.append(order)
            shares_traded.append(shares)

        return self._trades_frame(symbol, dates, orders, shares_traded)

    def addEvidence(self, symbol, pd_prices, pd_features,
                    sd=dt.datetime(2008, 1, 1), ed=dt.datetime(2009, 12, 31), sv=10000):
        """Train the Q-learner on the ``sd``..``ed`` window.

        Repeats passes over the window until the produced trades have matched
        the previous epoch's trades more than ``COV_EPOCH`` times (counted
        only after ``MIN_EPOCH`` epochs), or ``MAX_EPOCH`` epochs have run.

        Args:
            symbol: value written to the ``entity_symbol`` column.
            pd_prices: DataFrame with an ``adj_close`` column, sliceable by
                ``sd:ed`` on its index.
            pd_features: DataFrame containing the ``STATE_FEATURES`` columns,
                indexed like ``pd_prices``.
            sd: start of the training window.
            ed: end of the training window.
            sv: starting value; one position is ``sv / first adj_close`` shares.

        Returns:
            DataFrame of the last epoch's trades (rows with non-zero
            ``shares`` only) with columns ``entity_symbol``, ``order``,
            ``shares``.
        """
        pd_prices = pd_prices.loc[sd:ed, :]
        pd_features = pd_features.loc[sd:ed, :]

        states = self._build_states(pd_features)
        returns = self._daily_returns(pd_prices)

        self.ql.querysetstate(states.iloc[0])

        trade_size = sv / pd_prices['adj_close'].iloc[0]

        # Before the first epoch both "latest" and "previous" are all-HOLD.
        hold_rows = ['HOLD'] * len(pd_prices.index)
        latest = self._trades_frame(symbol, pd_prices.index, hold_rows, [0] * len(hold_rows))
        previous = latest

        repeats = 0
        for epoch in range(1, self.MAX_EPOCH + 1):
            # Repeats are cumulative: the counter is not reset when two
            # consecutive epochs differ.
            if epoch > self.MIN_EPOCH and latest.equals(previous):
                repeats += 1
                if repeats > self.COV_EPOCH:
                    break

            previous = latest
            latest = self._run_pass(symbol, pd_prices, states, returns, trade_size)

        return latest.loc[latest['shares'] != 0, :]

    def testPolicy(self, symbol, pd_prices, pd_features,
                   sd=dt.datetime(2010, 1, 1), ed=dt.datetime(2011, 12, 31), sv=10000):
        """Run the learner once over the ``sd``..``ed`` window and return its trades.

        Uses the same state encoding as ``addEvidence`` so that the states
        looked up here are the ones the Q-table was trained on.

        NOTE(review): the pass calls ``self.ql.query(state, reward)``, exactly
        as training does.  Whether that call updates the Q-table or takes
        random actions depends on QLearner, which is defined elsewhere;
        confirm that this is the intended call for evaluation.

        Args:
            symbol: value written to the ``entity_symbol`` column.
            pd_prices: DataFrame with an ``adj_close`` column, sliceable by
                ``sd:ed`` on its index.
            pd_features: DataFrame containing the ``STATE_FEATURES`` columns,
                indexed like ``pd_prices``.
            sd: start of the evaluation window.
            ed: end of the evaluation window.
            sv: starting value; one position is ``sv / first adj_close`` shares.

        Returns:
            DataFrame of trades (rows with non-zero ``shares`` only) with
            columns ``entity_symbol``, ``order``, ``shares``.
        """
        pd_prices = pd_prices.loc[sd:ed]
        pd_features = pd_features.loc[sd:ed]

        states = self._build_states(pd_features)
        returns = self._daily_returns(pd_prices)

        self.ql.querysetstate(states.iloc[0])

        trade_size = sv / pd_prices['adj_close'].iloc[0]

        pd_trades = self._run_pass(symbol, pd_prices, states, returns, trade_size)

        return pd_trades.loc[pd_trades['shares'] != 0, :]
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.7.6" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment