Skip to content

Instantly share code, notes, and snippets.

@buswedg
Created April 12, 2020 04:07
Show Gist options
  • Save buswedg/3d52b04a7d2d871cc56bf0850866944a to your computer and use it in GitHub Desktop.
Save buswedg/3d52b04a7d2d871cc56bf0850866944a to your computer and use it in GitHub Desktop.
reinforcement_learning_for_share_trading\strategy_learner
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns\n",
"sns.set_style('darkgrid')\n",
"\n",
"import datetime as dt\n",
"\n",
"import pandas as pd\n",
"\n",
"from q_learner import *\n",
"from indicators import *\n",
"from col_refs import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class StrategyLearner(object):\n",
"\n",
" def __init__(self, impact=0.0, verbose=False):\n",
" self.impact = impact\n",
" self.verbose = verbose\n",
"\n",
" self.ql = QLearner(num_states=9999, num_actions=3, alpha=0.2, gamma=0.9, rar=0.98, radr=0.999, dyna=0,\n",
" verbose=False)\n",
"\n",
" def get_bins(self, ps_feature, num_steps):\n",
" step_size = int(len(ps_feature.index) / (num_steps + 1))\n",
" ps_feature = ps_feature.sort_values()\n",
"\n",
" bins = []\n",
" for s in range(0, num_steps + 1):\n",
" if s == 0:\n",
" bins.append(ps_feature.iloc[0])\n",
"\n",
" elif s < num_steps:\n",
" bins.append(ps_feature.iloc[s * step_size])\n",
"\n",
" else:\n",
" bins.append(ps_feature.iloc[-1])\n",
"\n",
" return bins\n",
"\n",
"\n",
" def get_trade(self, action, holdings, trade_size):\n",
"\n",
" if (action == 1) and (holdings < trade_size):\n",
" order = 'BUY'\n",
"\n",
" if holdings == 0:\n",
" shares = trade_size\n",
" holdings += trade_size\n",
" else:\n",
" shares = (2 * trade_size)\n",
" holdings += (2 * trade_size)\n",
"\n",
" elif (action == 2) and (holdings > -trade_size):\n",
" order = 'SELL'\n",
"\n",
" if holdings == 0:\n",
" shares = trade_size\n",
" holdings -= trade_size\n",
" else:\n",
" shares = (2 * trade_size)\n",
" holdings -= (2 * trade_size)\n",
"\n",
" else:\n",
" order = 'HOLD'\n",
" shares = 0\n",
"\n",
" return order, shares, holdings\n",
"\n",
"\n",
" def addEvidence(self, symbol, pd_prices, pd_features,\n",
" sd=dt.datetime(2008, 1, 1), ed=dt.datetime(2009, 12, 31), sv=10000):\n",
"\n",
" pd_prices = pd_prices.loc[sd:ed, :]\n",
" pd_features = pd_features.loc[sd:ed, :]\n",
"\n",
" # pd_states = pd_features[['volume_adi', 'volume_obv', 'volume_cmf', 'volume_vpt', 'volume_nvi']].copy()\n",
" pd_states = pd_features[['momentum_roc', 'momentum_rsi', 'trend_macd_diff', 'volatility_bbm']].copy()\n",
"\n",
" num_steps = 9 # note num_states for ql\n",
"\n",
" for c in pd_states.columns:\n",
" ls_bins = self.get_bins(pd_states[c], num_steps)\n",
" pd_states[c] = pd.cut(pd_states[c], bins=ls_bins, labels=range(0, num_steps)).fillna(0) # duplicates='drop'\n",
"\n",
" pd_states['state'] = pd_states.applymap(str).sum(axis=1).astype(int)\n",
"\n",
" sym_prices = pd_prices['adj_close']\n",
" sym_returns = (sym_prices[1:] / sym_prices[:-1].values) - 1\n",
"\n",
" ps_symbol = pd.DataFrame(symbol, index=pd_prices.index, columns=['entity_symbol'])\n",
" ps_order = pd.DataFrame('HOLD', index=pd_prices.index, columns=['order'])\n",
" ps_shares = pd.DataFrame(0, index=pd_prices.index, columns=['shares'])\n",
"\n",
" pd_trades = pd.concat([ps_symbol, ps_order, ps_shares], axis=1)\n",
" pd_trades.columns = ['entity_symbol', 'order', 'shares']\n",
"\n",
" initial_state = pd_states['state'].iloc[0]\n",
" self.ql.querysetstate(initial_state)\n",
"\n",
" trade_size = sv / pd_prices['adj_close'].iloc[0]\n",
"\n",
" pd_trades_copy = pd_trades.copy()\n",
"\n",
" i = 0\n",
" j = 0\n",
"\n",
" min_epoch = 20\n",
" cov_epoch = 40\n",
" max_epoch = 500\n",
"\n",
" while i < max_epoch:\n",
"\n",
" i += 1\n",
" holdings = 0\n",
"\n",
" if pd_trades.equals(pd_trades_copy):\n",
" if i > min_epoch:\n",
" j += 1\n",
"\n",
" if j > cov_epoch:\n",
" break\n",
"\n",
" pd_trades_copy = pd_trades.copy()\n",
"\n",
" for index, row in pd_prices[1:].iterrows():\n",
" state = pd_states.loc[index, 'state']\n",
" reward = holdings * sym_returns.loc[index] * (1 - self.impact)\n",
" action = self.ql.query(state, reward)\n",
"\n",
" order, shares, holdings = self.get_trade(action, holdings, trade_size)\n",
"\n",
" ps_order.loc[index]['order'] = order\n",
" ps_shares.loc[index]['shares'] = shares\n",
"\n",
" pd_trades = pd.concat([ps_symbol, ps_order, ps_shares], axis=1)\n",
" pd_trades.columns = ['entity_symbol', 'order', 'shares']\n",
"\n",
" pd_trades = pd_trades.loc[pd_trades['shares'] != 0, :]\n",
"\n",
" return pd_trades\n",
"\n",
"\n",
" def testPolicy(self, symbol, pd_prices, pd_features,\n",
" sd=dt.datetime(2010, 1, 1), ed=dt.datetime(2011, 12, 31), sv=10000):\n",
"\n",
" pd_prices = pd_prices.loc[sd:ed]\n",
" pd_features = pd_features.loc[sd:ed]\n",
"\n",
" # pd_states = pd_features[['volume_adi', 'volume_obv', 'volume_cmf', 'volume_vpt', 'volume_nvi']].copy()\n",
" pd_states = pd_features[['momentum_roc', 'momentum_rsi', 'trend_macd', 'volatility_bbm']].copy()\n",
"\n",
" num_steps = 5\n",
"\n",
" for c in pd_states.columns:\n",
" ls_bins = self.get_bins(pd_states[c], num_steps)\n",
" pd_states[c] = pd.cut(pd_states[c], bins=ls_bins, labels=range(0, num_steps)).fillna(0) # duplicates='drop'\n",
"\n",
" pd_states['state'] = pd_states.applymap(str).sum(axis=1).astype(int)\n",
"\n",
" sym_prices = pd_prices['adj_close']\n",
" sym_returns = (sym_prices[1:] / sym_prices[:-1].values) - 1\n",
"\n",
" ps_symbol = pd.DataFrame(symbol, index=pd_prices.index, columns=['entity_symbol'])\n",
" ps_order = pd.DataFrame('HOLD', index=pd_prices.index, columns=['order'])\n",
" ps_shares = pd.DataFrame(0, index=pd_prices.index, columns=['shares'])\n",
"\n",
" pd_trades = pd.concat([ps_symbol, ps_order, ps_shares], axis=1)\n",
" pd_trades.columns = ['entity_symbol', 'order', 'shares']\n",
"\n",
" initial_state = pd_states['state'].iloc[0]\n",
" self.ql.querysetstate(initial_state)\n",
"\n",
" trade_size = sv / pd_prices['adj_close'].iloc[0]\n",
"\n",
" holdings = 0\n",
"\n",
" for index, row in pd_prices[1:].iterrows():\n",
" state = pd_states.loc[index, 'state']\n",
" reward = holdings * sym_returns.loc[index] * (1 - self.impact)\n",
" action = self.ql.query(state, reward)\n",
"\n",
" order, shares, holdings = self.get_trade(action, holdings, trade_size)\n",
"\n",
" ps_order.loc[index]['order'] = order\n",
" ps_shares.loc[index]['shares'] = shares\n",
"\n",
" pd_trades = pd.concat([ps_symbol, ps_order, ps_shares], axis=1)\n",
" pd_trades.columns = ['entity_symbol', 'order', 'shares']\n",
"\n",
" pd_trades = pd_trades.loc[pd_trades['shares'] != 0, :]\n",
"\n",
" return pd_trades"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment