@jsrimr
Created September 6, 2019 06:34
import os
import torch
import dqn_agent                    # project-specific DQN agent module (assumed to live alongside this script)
from trading_env import TradingEnv  # assumed import path for the TradingGym environment

def main():
    env = TradingEnv(custom_args=args, env_id='custom_trading_env', obs_data_len=obs_data_len,
                     step_len=step_len, sample_len=sample_len, df=df, fee=fee, initial_budget=1,
                     n_action_intervals=n_action_intervals, deal_col_name='c', sell_at_end=True,
                     feature_names=['o', 'h', 'l', 'c', 'v', 'num_trades', 'taker_base_vol'])
    agent = dqn_agent.Agent(action_size=2 * n_action_intervals + 1, obs_len=obs_data_len,
                            num_features=env.reset().shape[-1], **hyperparams)
    # load the pretrained Rainbow weights and move the local Q-network to the target device
    agent.qnetwork_local.load_state_dict(
        torch.load(os.path.join(load_location, 'TradingGym_Rainbow_1000.pth'), map_location=device))
    agent.qnetwork_local.to(device)
    for eps in range(500):  # n_episode = 500
        state = env.reset()
        done = False
        while not done:
            # the agent is assumed to collect experience internally; learn() is called as in the original gist
            next_state, reward, done, _ = env.step(agent.act(state))
            agent.learn()
            state = next_state
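
# --- Hypothetical setup sketch ----------------------------------------------
# The gist references args, df, obs_data_len, step_len, sample_len, fee,
# n_action_intervals, hyperparams, load_location, and device without defining
# them. Everything below is an assumed, illustrative way to supply those names;
# the concrete values, CSV columns, checkpoint path, and hyperparameter keys are
# placeholders and must match your own data and the dqn_agent.Agent constructor.
import argparse
import pandas as pd

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

args = argparse.Namespace()           # stand-in for the gist's custom_args
df = pd.read_csv('ohlcv.csv')         # expects 'o','h','l','c','v','num_trades','taker_base_vol' columns
obs_data_len, step_len, sample_len = 256, 16, 480   # example window / step / sample lengths
fee, n_action_intervals = 0.001, 5                  # example trading fee and action discretization
load_location = './checkpoints'       # directory holding TradingGym_Rainbow_1000.pth
hyperparams = {'buffer_size': 100000, 'batch_size': 32, 'gamma': 0.99, 'lr': 1e-4}  # example Agent kwargs

if __name__ == '__main__':
    main()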