Skip to content

Instantly share code, notes, and snippets.

@buswedg
Created April 12, 2020 04:07
Show Gist options
  • Save buswedg/a30331b78b47dcc1418804fae95639ef to your computer and use it in GitHub Desktop.
Save buswedg/a30331b78b47dcc1418804fae95639ef to your computer and use it in GitHub Desktop.
reinforcement_learning_for_share_trading\utils
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"sns.set_style('darkgrid')\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"import yfinance as yf # will likely need to pip install\n",
"# note an issue with yfinance here: github.com/ranaroussi/yfinance/issues/214\n",
"\n",
"from indicators import *\n",
"from col_refs import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_price(entity_symbol, data_source='yfinance', table_type='eod_price'):\n",
"    \"\"\"Download the full price history for one symbol and normalise its columns.\n",
"\n",
"    The history is fetched with ``yf.Ticker(entity_symbol).history`` using\n",
"    ``period='max'`` and with auto-adjustment and rounding switched off.\n",
"    Columns are renamed through ``dic_cols_rename_ref[data_source][table_type]``.\n",
"\n",
"    Args:\n",
"        entity_symbol: Ticker symbol handed to ``yf.Ticker``.\n",
"        data_source: First-level key into ``dic_cols_rename_ref``. The data\n",
"            itself is always fetched through yfinance.\n",
"        table_type: Second-level key into ``dic_cols_rename_ref``.\n",
"\n",
"    Returns:\n",
"        The renamed price frame with an added ``entity_symbol`` column, the\n",
"        ``dividend`` and ``split`` columns removed when present, and an\n",
"        unnamed index.\n",
"    \"\"\"\n",
"    ticker = yf.Ticker(entity_symbol)\n",
"    pd_price = ticker.history(period='max', auto_adjust=False, rounding=False)\n",
"\n",
"    dic_cols_rename = dic_cols_rename_ref[data_source][table_type]\n",
"\n",
"    # Columns missing from the rename map keep their raw names; flag them.\n",
"    if not set(pd_price.columns).issubset(dic_cols_rename.keys()):\n",
"        print('WARNING: unknown columns encountered')\n",
"\n",
"    pd_price = pd_price.rename(columns=dic_cols_rename)\n",
"    pd_price['entity_symbol'] = entity_symbol\n",
"\n",
"    # errors='ignore' tolerates a frame that lacks either column after the\n",
"    # rename, instead of aborting the download with a KeyError.\n",
"    pd_price = pd_price.drop(['dividend', 'split'], axis=1, errors='ignore')\n",
"    pd_price.index.name = None\n",
"\n",
"    return pd_price"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_indicator(pd_prices):\n",
"    \"\"\"Compute indicators for a price frame and fill the remaining gaps.\n",
"\n",
"    Args:\n",
"        pd_prices: Frame passed to ``calc_indicators`` with ``high``, ``low``,\n",
"            ``close`` and ``volume`` named as its price columns.\n",
"\n",
"    Returns:\n",
"        The ``calc_indicators`` result with all-NaN columns dropped and the\n",
"        remaining NaNs forward-filled, then back-filled.\n",
"    \"\"\"\n",
"    pd_indicator = calc_indicators(pd_prices, col_high='high', col_low='low', col_close='close',\n",
"                                   col_volume='volume', bool_fillna=False)\n",
"\n",
"    # Only columns that are entirely NaN are removed; partly empty columns\n",
"    # are kept and filled below.\n",
"    pd_indicator = pd_indicator.dropna(axis=1, how='all')\n",
"\n",
"    # ffill()/bfill() replace fillna(method=...), which pandas has deprecated.\n",
"    # Forward fill runs first so back fill only touches leading gaps.\n",
"    return pd_indicator.ffill().bfill()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def compute_portvals(pd_orders, pd_prices, sv=1000000, comm=9.95, imp=0.005):\n",
"    \"\"\"Replay orders against closing prices and return the portfolio value.\n",
"\n",
"    Args:\n",
"        pd_orders: Orders indexed by date with ``entity_symbol``, ``order``\n",
"            and ``shares`` columns. ``order == 'BUY'`` buys; any other value\n",
"            is treated as a sell.\n",
"        pd_prices: Long-format prices indexed by date with ``entity_symbol``\n",
"            and ``close`` columns.\n",
"        sv: Starting cash, credited on the first order date.\n",
"        comm: Fixed commission charged per order.\n",
"        imp: Impact rate, charged as a fraction of each trade value.\n",
"\n",
"    Returns:\n",
"        DataFrame with a single ``port_val`` column (cash balance plus\n",
"        holdings valued at the close), indexed by the price dates between the\n",
"        first and the last order.\n",
"    \"\"\"\n",
"    ls_symbols = pd_prices['entity_symbol'].unique()\n",
"\n",
"    sd = pd_orders.index.min()\n",
"    ed = pd_orders.index.max()\n",
"\n",
"    # A boolean window works on unsorted or repeated dates, which label\n",
"    # slicing with .loc[sd:ed] does not.\n",
"    in_window = (pd_prices.index >= sd) & (pd_prices.index <= ed)\n",
"    pd_prices = pd_prices.loc[in_window, :]\n",
"\n",
"    # Long-format prices repeat every date once per symbol; the wide table\n",
"    # needs each date exactly once so the per-symbol series can align.\n",
"    dates = pd_prices.index.unique().sort_values()\n",
"    pd_prices_all = pd.DataFrame(index=dates)\n",
"\n",
"    for symb in ls_symbols:\n",
"        pd_prices_all[symb] = pd_prices.loc[pd_prices['entity_symbol'] == symb, 'close']\n",
"\n",
"    # Float zeros: the ledgers receive float amounts below, and newer pandas\n",
"    # warns about or rejects upcasting an integer column on assignment.\n",
"    pd_account = pd.DataFrame(0.0, index=dates, columns=['credit', 'debit', 'fees'])\n",
"    pd_positions = pd.DataFrame(0.0, index=dates, columns=pd_prices_all.columns)\n",
"\n",
"    pd_account.loc[sd, 'credit'] = sv\n",
"\n",
"    for index, row in pd_orders.iterrows():\n",
"        symb = row['entity_symbol']\n",
"        shares = row['shares']\n",
"\n",
"        price = pd_prices_all.loc[index, symb]\n",
"        trade_value = price * shares\n",
"        trade_cost = comm + imp * trade_value\n",
"\n",
"        if row['order'] == 'BUY':\n",
"            pd_positions.loc[index, symb] += shares\n",
"            pd_account.loc[index, 'debit'] -= trade_value\n",
"        else:\n",
"            pd_positions.loc[index, symb] -= shares\n",
"            pd_account.loc[index, 'credit'] += trade_value\n",
"\n",
"        # Commission and impact are charged on every order, buy or sell.\n",
"        pd_account.loc[index, 'fees'] -= trade_cost\n",
"\n",
"    # Per-day share changes become running positions.\n",
"    pd_positions = pd_positions.cumsum()\n",
"\n",
"    # Per-day cash flows become a running cash balance.\n",
"    pd_account['balance'] = pd_account[['credit', 'debit', 'fees']].sum(axis=1).cumsum()\n",
"\n",
"    pd_holdings = pd_positions * pd_prices_all\n",
"    pd_account['holdings'] = pd_holdings.sum(axis=1)\n",
"    pd_account['value'] = pd_account[['balance', 'holdings']].sum(axis=1)\n",
"\n",
"    return pd.DataFrame(pd_account['value'].values, pd_account.index, ['port_val'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def compute_portfolio_stats(port_val, rfr=0.0, sf=252.0):\n",
"    \"\"\"Summarise a series of portfolio values.\n",
"\n",
"    Args:\n",
"        port_val: Portfolio values in time order.\n",
"        rfr: Per-period risk-free return subtracted from every period return.\n",
"        sf: Scaling factor; its square root multiplies the Sharpe ratio.\n",
"\n",
"    Returns:\n",
"        Tuple ``(cr, adr, sddr, sr)``: cumulative return, mean period return,\n",
"        standard deviation of period returns and scaled Sharpe ratio.\n",
"    \"\"\"\n",
"    period_returns = (port_val / port_val.shift(1)) - 1\n",
"    cumulative_return = (port_val.iloc[-1] / port_val.iloc[0]) - 1\n",
"\n",
"    # The first period has no predecessor, so its return is discarded.\n",
"    period_returns = period_returns[1:]\n",
"\n",
"    mean_return = period_returns.mean()\n",
"    volatility = period_returns.std()\n",
"    mean_excess_return = (period_returns - rfr).mean()\n",
"\n",
"    sharpe_ratio = np.sqrt(sf) * (mean_excess_return / volatility)\n",
"\n",
"    return cumulative_return, mean_return, volatility, sharpe_ratio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def market_simulator(pd_orders, pd_benchmark, pd_prices,\n",
"                     sv=1000000, comm=9.95, imp=0.005,\n",
"                     daily_rf=0.0, samples_per_year=252.0,\n",
"                     save_fig=False, gen_stats=True,\n",
"                     fig_name='plot.png', stats_name='stats.tsv'):\n",
"    \"\"\"Simulate a strategy and a benchmark, then report and plot both.\n",
"\n",
"    Args:\n",
"        pd_orders: Strategy orders, passed to ``compute_portvals``.\n",
"        pd_benchmark: Benchmark orders, passed to ``compute_portvals``.\n",
"        pd_prices: Price frame shared by both simulations.\n",
"        sv: Starting cash for both simulations.\n",
"        comm: Fixed commission per order.\n",
"        imp: Impact rate applied to each trade value.\n",
"        daily_rf: Risk-free return forwarded to ``compute_portfolio_stats``.\n",
"        samples_per_year: Scaling factor forwarded to ``compute_portfolio_stats``.\n",
"        save_fig: When true the chart is saved to ``fig_name``; otherwise it\n",
"            is shown interactively.\n",
"        gen_stats: When true the statistics table is written to ``stats_name``\n",
"            as tab-separated text.\n",
"        fig_name: Output path for the chart.\n",
"        stats_name: Output path for the statistics table.\n",
"    \"\"\"\n",
"    portvals = compute_portvals(pd_orders, pd_prices, sv, comm, imp)\n",
"    benchvals = compute_portvals(pd_benchmark, pd_prices, sv, comm, imp)\n",
"\n",
"    port_stats = compute_portfolio_stats(portvals, rfr=daily_rf, sf=samples_per_year)\n",
"    bench_stats = compute_portfolio_stats(benchvals, rfr=daily_rf, sf=samples_per_year)\n",
"\n",
"    # Each statistic is a one-element Series labelled 'port_val'. Take the\n",
"    # value with .iloc[0]: integer keys on a labelled Series are deprecated\n",
"    # as positional lookups.\n",
"    cr_port, adr_port, sddr_port, sr_port = (stat.iloc[0] for stat in port_stats)\n",
"    cr_bench, adr_bench, sddr_bench, sr_bench = (stat.iloc[0] for stat in bench_stats)\n",
"\n",
"    stat_columns = ['cumulative_return', 'average_daily_return', 'std_dev_of_returns',\n",
"                    'sharpe_ratio', 'number_of_trades']\n",
"\n",
"    pd_stats = pd.DataFrame(\n",
"        {'cumulative_return': ['{:0.6f}'.format(cr_port), '{:0.6f}'.format(cr_bench)],\n",
"         'average_daily_return': ['{:0.6f}'.format(adr_port), '{:0.6f}'.format(adr_bench)],\n",
"         'std_dev_of_returns': ['{:0.6f}'.format(sddr_port), '{:0.6f}'.format(sddr_bench)],\n",
"         'sharpe_ratio': ['{:0.4f}'.format(sr_port), '{:0.4f}'.format(sr_bench)],\n",
"         'number_of_trades': [len(pd_orders), len(pd_benchmark)]},\n",
"        columns=stat_columns,\n",
"        index=['portfolio', 'benchmark'])\n",
"\n",
"    if gen_stats:\n",
"        pd_stats.to_csv(stats_name, sep='\\t')\n",
"\n",
"    # Normalise both curves to start at 1 and align them on their dates.\n",
"    # concat joins on the index itself, so it does not depend on the index\n",
"    # being unnamed the way reset_index() + merge(on='index') does.\n",
"    portvals_norm = portvals['port_val'] / portvals['port_val'].iloc[0]\n",
"    benchvals_norm = benchvals['port_val'] / benchvals['port_val'].iloc[0]\n",
"\n",
"    vals_norm = pd.concat([portvals_norm, benchvals_norm], axis=1, keys=['portfolio', 'benchmark'])\n",
"    vals_norm = vals_norm.sort_index().ffill()\n",
"\n",
"    fig = plt.figure(figsize=(20, 12))\n",
"    ax = fig.add_subplot(111)\n",
"\n",
"    ax.plot(vals_norm.index, vals_norm['portfolio'], color='black', label='portfolio')\n",
"    ax.plot(vals_norm.index, vals_norm['benchmark'], color='blue', label='benchmark')\n",
"\n",
"    # Walk (date, side) pairs instead of looking each date up with .loc:\n",
"    # when several orders share a date the lookup returns a Series, whose\n",
"    # comparison cannot be used as an if-condition.\n",
"    for date, side in zip(pd_orders.index, pd_orders['order']):\n",
"        ax.axvline(date, color='g' if side == 'BUY' else 'r', alpha=0.5)\n",
"\n",
"    ax.set_title('strategy_learner portfolio vs. benchmark')\n",
"    ax.set_ylabel('normalized value')\n",
"    ax.legend(loc='upper right')\n",
"\n",
"    if save_fig:\n",
"        fig.savefig(fig_name, bbox_inches='tight')\n",
"    else:\n",
"        plt.interactive(True)\n",
"        plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment