This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import numpy as np | |
import pandas as pd | |
import trading_env | |
# Load the dataset stored under key 'STW' in dataset/SGXTW.h5 and build a
# trading environment from it via `trading_env.make`.
# NOTE(review): this excerpt is truncated — the `feature_names` list and the
# `trading_env.make(...)` call are not closed in the visible text; confirm
# the remaining arguments against the full script.
df = pd.read_hdf('dataset/SGXTW.h5', 'STW')
env = trading_env.make(env_id='training_v1', obs_data_len=256, step_len=128,
df=df, fee=0.1, max_position=5, deal_col_name='Price',
feature_names=['Price', 'Volume',
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Agent:
    """Skeleton of the trading agent; method bodies are elided (``...``) in this excerpt.

    NOTE(review): the elided code presumably builds and trains the underlying
    model — confirm against the full source.
    """

    def __init__(self, risk_aversion, **args):
        # `risk_aversion` and the extra keyword options are consumed by elided code.
        ...

    def model(self, state):
        """Evaluate the underlying model on ``state`` (body elided in this excerpt).

        Takes ``self`` so it can be called as a method; the original
        zero-argument signature made ``instance.model(...)`` a TypeError.
        """
        ...

    def act(self, state, eps=0.):
        """Return the model's output for ``state``.

        ``eps`` is presumably an exploration rate; it is not used in the
        visible code.
        """
        ...
        # Call the model through the instance: a bare `model` looks up a
        # global name, not this method, and raised NameError.
        return self.model(state)

    def learn(self, experiences, is_weights, gamma):
        """Update the agent from a batch of ``experiences`` (body elided in this excerpt)."""
        ...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TradingEnv:
    """Excerpt of the custom trading environment; method bodies are elided here."""

    def _long(self):
        """Handle a buy action (body elided in this excerpt)."""

    def _long_cover(self, current_price_mean, current_mkt_position, action):
        """Sell the held position (body elided in this excerpt)."""

    def step(self, action):
        """Advance the environment by one step for ``action``.

        Per the original outline, the elided body processes the buy/sell
        action, updates the agent's position, and returns the next state and
        the reward.
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_stochastic(df, n=15, m=5, t=3):
    """Compute stochastic-oscillator lines from the ``h``, ``l`` and ``c`` columns.

    Args:
        df: frame with ``h``, ``l`` and ``c`` columns.
        n: look-back window for the rolling high/low (``min_periods=1``, so
            the first rows use a shorter window).
        m: EWM span used to smooth Fast%K into Fast%D.
        t: presumably the EWM span for the Slow%D step, whose code is not
            visible in this excerpt — TODO confirm.

    NOTE(review): the excerpt ends after the ``# Slow%D`` comment with no
    ``return``, so as shown the function returns ``None``. The indicator
    dispatch (``fn``) treats it like the other indicator functions, whose
    results are used in fee-rate arithmetic — confirm the missing tail.
    """
    # highest price during n days
    ndays_high = df.h.rolling(window=n, min_periods=1).max()
    # lowest price during n days
    ndays_low = df.l.rolling(window=n, min_periods=1).min()
    # Fast%K, as a fraction rather than a percentage.
    # NOTE(review): wherever high == low over the window the denominator is 0,
    # giving NaN/inf — confirm whether that needs handling.
    kdj_k = ((df.c - ndays_low) / (ndays_high - ndays_low))
    # Fast%D (=Slow%K)
    kdj_d = kdj_k.ewm(span=m).mean()
    # Slow%D
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def fnRSI(m_Df, m_N=15):
    """Return the average RSI of the ``c`` column, expressed as a fraction.

    Gains and losses are the positive and negative parts of the one-step
    difference of ``m_Df['c']``. Each is averaged with a simple rolling mean
    over ``m_N`` rows (incomplete windows are NaN). The per-row value is
    ``avg_gain / (avg_gain + avg_loss)``, which lies in [0, 1] rather than
    the 0-100 scale.

    Args:
        m_Df: frame with a ``c`` column.
        m_N: rolling-window length.

    Returns:
        The mean of the per-row RSI values, skipping NaN. Windows with no
        price change (0 / 0) are NaN and are skipped. The result is NaN when
        every window is NaN, e.g. fewer than ``m_N`` rows or a flat series.
    """
    close = m_Df['c']
    # Compute the one-step difference once and reuse it for gains and losses.
    delta = close.diff(1)
    # The leading NaN from diff() fails both comparisons and so becomes 0;
    # it therefore still counts toward the first full window.
    gains = delta.where(delta > 0, 0.0)
    losses = (-delta).where(delta < 0, 0.0)
    avg_gain = gains.rolling(window=m_N, min_periods=m_N).mean()
    avg_loss = losses.rolling(window=m_N, min_periods=m_N).mean()
    return (avg_gain / (avg_gain + avg_loss)).mean()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_bollinger_diffs(df, n=20, k=2):
    """Return the mean Bollinger band width of the ``c`` column.

    The bands are ``MA_n + k * STD_n`` and ``MA_n - k * STD_n``, so their
    difference is exactly ``2 * k * STD_n``. Computing that directly removes
    the unused moving average and the duplicated rolling passes. It also
    avoids the cancellation error of subtracting two nearly equal band values.

    Args:
        df: frame with a ``c`` column.
        n: rolling-window length. The first ``n - 1`` widths are NaN; the
            standard deviation is the pandas default sample std (ddof=1).
        k: number of standard deviations between the moving average and each band.

    Returns:
        The mean band width over the rows where it is defined (NaN skipped).
    """
    rolling_std = df['c'].rolling(n).std()
    return (2 * k * rolling_std).mean()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def fnMACD(m_Df, m_NumFast=12, m_NumSlow=26, m_NumSignal=9):
    """Return the mean MACD histogram (MACD line minus signal line) of ``c``.

    Every EMA uses ``span`` smoothing with ``min_periods = span - 1``, so the
    first few values of each EMA are NaN; the final ``mean`` skips them.

    Args:
        m_Df: frame with a ``c`` column.
        m_NumFast: span of the fast EMA of the close.
        m_NumSlow: span of the slow EMA of the close.
        m_NumSignal: span of the EMA applied to the MACD line (the signal line).
    """
    def _ema(series, span):
        return series.ewm(span=span, min_periods=span - 1).mean()

    close = m_Df['c']
    macd_line = _ema(close, m_NumFast) - _ema(close, m_NumSlow)
    histogram = macd_line - _ema(macd_line, m_NumSignal)
    return histogram.mean()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def step(self):
    """Excerpt of the environment step that rescales the fee rate.

    The indicator stored in ``self.get_derivative_diffs`` is evaluated on the
    current observation window. The fee rate is multiplied by the ratio of
    that value to ``self.previous_diff``, then clamped to
    ``[self.min_fee_rate, self.max_fee_rate]``.

    NOTE(review): the update of ``self.previous_diff`` is not visible in this
    excerpt — confirm it is refreshed each step. If ``previous_diff`` is 0 or
    NaN, the ratio becomes inf or NaN. ``np.clip`` maps +/-inf to a bound but
    leaves NaN as NaN, which would then carry into every later fee update.
    NOTE(review): this excerpt takes no ``action`` argument, unlike the
    ``step(self, action)`` outline — confirm which signature is current.
    """
    ...
    # Indicator value over rows [step_st, step_st + obs_len) of the sampled data.
    derivative_diff = self.get_derivative_diffs(self.df_sample.iloc[self.step_st: self.step_st + self.obs_len])
    # Multiplicative update: the fee follows the indicator's relative change, kept within bounds.
    self.fee_rate = np.clip( self.fee_rate * derivative_diff / self.previous_diff, self.min_fee_rate, self.max_fee_rate)
    ...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def main(n_episode=500):
    """Run a pretrained Rainbow DQN agent in the custom trading environment.

    Builds ``TradingEnv`` over ``df`` and restores the agent's
    ``qnetwork_local`` weights from ``TradingGym_Rainbow_1000.pth`` under
    ``load_location``. It then performs ``n_episode`` iterations, each one
    ``env.step`` with the agent's chosen action.

    NOTE(review): relies on module-level names defined outside this excerpt
    (args, df, fee, obs_data_len, step_len, sample_len, n_action_intervals,
    hyperparams, load_location, device, TradingEnv, dqn_agent).

    Args:
        n_episode: number of loop iterations (default 500, as originally intended).
    """
    env = TradingEnv(custom_args=args, env_id='custom_trading_env', obs_data_len=obs_data_len,
                     step_len=step_len, sample_len=sample_len, df=df, fee=fee, initial_budget=1,
                     n_action_intervals=n_action_intervals, deal_col_name='c', sell_at_end=True,
                     feature_names=['o', 'h', 'l', 'c', 'v', 'num_trades', 'taker_base_vol'])
    # A single reset both sizes the network input and supplies the first state;
    # previously `state` was never defined before `agent.act(state)` (NameError).
    state = env.reset()
    agent = dqn_agent.Agent(action_size=2 * n_action_intervals + 1, obs_len=obs_data_len,
                            num_features=state.shape[-1], **hyperparams)
    agent.qnetwork_local.load_state_dict(
        torch.load(os.path.join(load_location, 'TradingGym_Rainbow_1000.pth'), map_location=device))
    agent.qnetwork_local.to(device)
    # `range` takes no keyword arguments, so `range(n_episode=500)` raised TypeError.
    for _step in range(n_episode):
        next_state, reward, done, _ = env.step(agent.act(state))
        # Feed the new observation back in; start over once the episode ends
        # (assumes gym-style reset semantics — confirm against TradingEnv).
        state = env.reset() if done else next_state
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Indicator functions selectable by `custom_args.environment`; the chosen one is
# evaluated on each observation window and its result rescales the fee rate in step().
fn = {'stochastic': get_stochastic, 'rsi': fnRSI, 'macd': fnMACD, 'bollinger': get_bollinger_diffs}
# Fail fast on an unknown name. `fn.get` would store None, and the first
# step() would then die with an opaque "'NoneType' object is not callable".
try:
    self.get_derivative_diffs = fn[custom_args.environment]
except KeyError:
    raise ValueError(
        f"unknown environment {custom_args.environment!r}; expected one of {sorted(fn)}"
    ) from None
Older / Newer