ppo_dqn_v4.py
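# ppo_dqn_v4: a PPO-style policy combined with a DQN-style Q-function (plus a
# periodically-synced target copy) for crypto trading. Trajectories are collected
# from a custom gymnasium environment (CryptoTradingEnv) whose observations are
# built from Bitvavo candle data (HistoricalData / CandlesDataset). Parts of the
# observation handling are still schematic, so treat this file as a sketch rather
# than a finished trading system.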
import copy
import logging
import time
from pathlib import Path

import gymnasium as gym
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from gymnasium import spaces
from python_bitvavo_api.bitvavo import Bitvavo
from torch.utils.data import DataLoader, Dataset
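# The logging.debug/logging.error calls below only produce output once logging is
# configured; something like the commented-out line below would do, with the level
# left to the caller:
# logging.basicConfig(level=logging.INFO)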
class CandlesDataset(Dataset):
    """Wraps a candle DataFrame as a torch Dataset of OHLCV features, with the
    close price as the label."""

    def __init__(self, df):
        # Coerce the OHLCV columns to numeric and drop rows that fail to parse.
        df[['open', 'high', 'low', 'close', 'volume']] = df[
            ['open', 'high', 'low', 'close', 'volume']
        ].apply(pd.to_numeric, errors='coerce')
        df.dropna(inplace=True)
        self.features = torch.tensor(
            df[['open', 'high', 'low', 'close', 'volume']].astype(float).values,
            dtype=torch.float32
        )
        self.labels = torch.tensor(
            df['close'].astype(float).values,
            dtype=torch.float32
        )

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return {'features': self.features[idx], 'label': self.labels[idx]}
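# Example usage (sketch), given the market_df built by HistoricalData below:
#
#     dataset = CandlesDataset(historical_data.market_df)
#     loader = DataLoader(dataset, batch_size=32, shuffle=False)
#     batch = next(iter(loader))
#     # batch['features'] -> (32, 5) float32 tensor, batch['label'] -> (32,)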
class HistoricalData:
    def __init__(self, market, interval, limit):
        self.market = market
        self.interval = interval
        self.limit = limit
        self.api_key, self.api_secret = self._load_config()
        self.bitvavo = Bitvavo(
            {'APIKEY': self.api_key, 'APISECRET': self.api_secret})
        self.bitvavo_data = self.fetch_candles()
        self.account_info = self.fetch_balance()
        self.market_df = self.build_market_df()
        self.state = self.get_state()
        self.next_state = self.get_next_state()
        self.dataset = CandlesDataset(self.market_df)

    def fetch_candles(self):
        candles = self.bitvavo.candles(
            self.market, self.interval, {'limit': self.limit})
        return candles

    def fetch_balance(self):
        return self.bitvavo.balance({})

    def get_state(self):
        try:
            logging.debug(f"Market data: {self.bitvavo_data}")
            logging.debug(f"Account info: {self.account_info}")
            state = {
                'market_data': self.bitvavo_data,
                'account_info': self.account_info
            }
            return state
        except Exception as e:
            logging.error(f"An error occurred during get_state: {e}")
            raise e

    def get_next_state(self):
        # TODO: this should return the *next* row of the market data; for now it
        # returns the first candle in the raw Bitvavo response.
        next_state = self.bitvavo_data[0]
        return next_state

    def build_market_df(self):
        df = pd.DataFrame(self.bitvavo_data, columns=[
            'timestamp', 'open', 'high', 'low', 'close', 'volume'])
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
        df.set_index('timestamp', drop=True, inplace=True)
        df.sort_index(inplace=True)
        return df

    def _load_config(self):
        config_path = Path(__file__).parents[1] / "bitvavo.yaml"
        try:
            with open(config_path, "r") as file:
                config = yaml.safe_load(file)
            return config['bitvavo']['api_key'], config['bitvavo']['api_secret']
        except FileNotFoundError as e:
            logging.error(
                f"Tried to load the Bitvavo API key and secret from "
                f"{config_path}, but the file was not found.")
            raise e
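# _load_config above expects a bitvavo.yaml one directory above this file. Based
# on the keys it reads, the file would look roughly like this (values are
# placeholders):
#
#     bitvavo:
#       api_key: "YOUR_API_KEY"
#       api_secret: "YOUR_API_SECRET"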
class CryptoTradingEnv(gym.Env):
    def __init__(self, market, interval, limit, max_steps, min_balance, profit_target):
        super().__init__()
        self.historical_data = HistoricalData(market, interval, limit)
        self.dataset = CandlesDataset(self.historical_data.market_df)
        self.state_space = spaces.Box(
            low=0, high=np.inf, shape=(limit + 2,), dtype=np.float32)
        self.action_space = spaces.Discrete(7)
        self.observation_space = spaces.Tuple(
            (self.state_space, self.action_space))
        self.reward_range = (-np.inf, np.inf)
        self.market = market
        self.interval = interval
        self.limit = limit
        self.max_steps = max_steps
        self.min_balance = min_balance
        self.profit_target = profit_target
        self.initial_balance = 1000
        self.current_balance = self.initial_balance
        self.holdings = np.zeros(3)
        self.prices = np.zeros(3)
        self.current_step = 0
        self.done = False
        self.info = {}
        self.reward = 0
        self.action = 0
        self.obs = self.historical_data.get_state()
    def reset(self):
        logging.debug("Resetting the environment")
        self.current_balance = self.initial_balance
        self.holdings = np.zeros(3)
        self.prices = np.zeros(3)
        self.current_step = 0
        self.done = False
        self.info = {}
        self.reward = 0
        self.action = 0
        self.obs = self.historical_data.get_state()
        logging.debug(f"Initial observation: {self.obs}")
        return self.obs

    def step(self, action):
        logging.debug(f"Taking action: {action}")
        assert self.action_space.contains(action), f"Invalid action: {action}"
        self.action = action
        self._trade()
        self.obs = self.historical_data.get_next_state()
        self.reward = self._calculate_reward()
        self.done = self._check_termination()
        self.info = {'balance': self.current_balance, 'holdings': self.holdings,
                     'prices': self.prices, 'pnl': self.current_balance - self.initial_balance}
        self.current_step += 1
        logging.debug(f"Observation: {self.obs}")
        return self.obs, self.reward, self.done, self.info

    def render(self, mode='human'):
        print(f"Step: {self.current_step}")
        print(f"Action: {self.action}")
        print(f"Balance: {self.current_balance}")
        print(f"Holdings: {self.holdings}")
        print(f"Prices: {self.prices}")
        print(f"Reward: {self.reward}")

    def close(self):
        pass
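    # Action layout used by _trade below: 0 = hold; (1, 2) = buy/sell asset 0;
    # (3, 4) = buy/sell asset 1; (5, 6) = buy/sell asset 2. Each trade uses 10%
    # of the current balance.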
    def _trade(self):
        # Assumes the observation is a flat array whose last three entries are
        # the current prices of the three tradable assets.
        self.prices = self.obs[-3:]
        action = self.action
        amount = self.current_balance * 0.1
        trade_actions = [
            lambda: None,
            lambda: self._buy(0, amount),
            lambda: self._sell(0, amount),
            lambda: self._buy(1, amount),
            lambda: self._sell(1, amount),
            lambda: self._buy(2, amount),
            lambda: self._sell(2, amount)
        ]
        trade_actions[action]()

    def _buy(self, index, amount):
        if amount > self.current_balance:
            amount = self.current_balance
        num_coins = amount / self.prices[index]
        self.current_balance -= amount
        self.holdings[index] += num_coins

    def _sell(self, index, amount):
        num_coins = amount / self.prices[index]
        if num_coins > self.holdings[index]:
            num_coins = self.holdings[index]
        amount = num_coins * self.prices[index]
        self.current_balance += amount
        self.holdings[index] -= num_coins
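    # Reward shaping: reward = PnL + Sharpe ratio - transaction costs, where the
    # Sharpe ratio is computed over returns of the price column and transaction
    # costs are approximated as 0.1% of the current balance.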
    def _calculate_reward(self):
        pnl = self.current_balance - self.initial_balance
        sharpe_ratio = self._calculate_sharpe_ratio()
        transaction_costs = self._calculate_transaction_costs()
        reward = pnl + sharpe_ratio - transaction_costs
        return reward

    def _calculate_sharpe_ratio(self):
        # Assumes the observation is a 2-D array whose last column holds prices.
        returns = np.diff(self.obs[:, -1]) / self.obs[:-1, -1]
        mean = np.mean(returns)
        std = np.std(returns)
        epsilon = 1e-8  # Small epsilon value to prevent division by zero
        sharpe_ratio = mean / (std + epsilon)
        return sharpe_ratio

    def _calculate_transaction_costs(self):
        transaction_costs = self.current_balance * 0.001
        return transaction_costs

    def _check_termination(self):
        if self.current_step > self.max_steps:
            return True
        if self.current_balance < self.min_balance:
            return True
        return False
class CustomCallback(torch.nn.Module):
    def __init__(self, agent):
        super().__init__()
        self.agent = agent

    def on_epoch_end(self, epoch, logs=None):
        pass
class PPOAgent:
    def __init__(self, state_space_dim, action_space_size, gamma=0.99, tau=0.95,
                 clip_range=0.2, batch_size=32, n_epochs=10, update_interval=100):
        self.state_space_dim = state_space_dim
        self.action_space_size = action_space_size
        self.gamma = gamma
        self.tau = tau
        self.clip_range = clip_range
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.update_interval = update_interval
        self.policy, self.Q_function = self.initialize_models()
        self.target_Q_function = copy.deepcopy(self.Q_function)
        self._max_episode_steps = None
        self._elapsed_steps = 0
        self._episode_started_at = None
        self._episode_ended_at = None
        self._episode_reward = 0
        self._episode_step = 0
        self._episode_info = None
        self._episode_done = False
        self._episode_state = None
        self._episode_action = None
        self._episode_next_state = None
    def initialize_models(self):
        """
        Initialize the policy and Q-function models.
        Returns:
            policy (nn.Module): Initialized policy model.
            Q_function (nn.Module): Initialized Q-function model.
        """
        policy = nn.Sequential(
            nn.Linear(self.state_space_dim, 64),
            nn.ReLU(),
            nn.Linear(64, self.action_space_size),
            nn.Softmax(dim=-1)
        )
        Q_function = nn.Sequential(
            nn.Linear(self.state_space_dim, 64),
            nn.ReLU(),
            nn.Linear(64, self.action_space_size)
        )
        return policy, Q_function
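    # Both networks share the same shape: a single 64-unit hidden layer over the
    # state vector. The policy head ends in a softmax over the discrete actions;
    # the Q-function head outputs one unbounded value per action.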
    def save_models(self, path):
        torch.save(self.policy.state_dict(), f"{path}/policy.pt")
        torch.save(self.Q_function.state_dict(), f"{path}/Q_function.pt")
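    # PPO clipped surrogate objective, as used in ppo_loss and update_policy:
    # with probability ratio r = pi_new(a|s) / pi_old(a|s) and advantage A,
    #     L = -mean( min( r * A, clip(r, 1 - clip_range, 1 + clip_range) * A ) )
    # so policy updates that push r outside the clip window gain no extra reward.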
    def ppo_loss(self, y_true, y_pred):
        """
        Compute the PPO loss.
        Args:
            y_true (torch.Tensor): Concatenation of one-hot actions, advantages
                and old action probabilities.
            y_pred (torch.Tensor): Predicted action probabilities.
        Returns:
            loss (torch.Tensor): PPO loss.
        """
        actions = y_true[:, :self.action_space_size]
        advantages = y_true[:, self.action_space_size:-1]
        old_probs_ph = y_true[:, -1]
        probs = torch.sum(y_pred * actions, dim=-1)
        new_log_probs = torch.log(probs + 1e-10)
        old_log_probs = torch.log(old_probs_ph + 1e-10)
        ratio = torch.exp(new_log_probs - old_log_probs)
        if torch.numel(advantages) > 0:
            surr1 = ratio[:, None] * advantages
            surr2 = torch.clamp(ratio, 1 - self.clip_range,
                                1 + self.clip_range)[:, None] * advantages
            loss = -torch.mean(torch.minimum(surr1, surr2))
        else:
            loss = torch.zeros_like(ratio)
            logging.error("Advantages tensor is empty. Returning zero loss.")
        return loss
    def collect_trajectories(self, env, buffer):
        """
        Collect trajectories by interacting with the environment.
        Args:
            env: Environment to interact with.
            buffer (dict): Buffer to store the collected trajectories.
        """
        state = env.reset()
        state = torch.tensor(state, dtype=torch.float32)
        self._episode_started_at = time.time()
        self._episode_ended_at = None
        self._episode_reward = 0
        self._episode_step = 0
        self._episode_info = None
        self._episode_done = False
        self._episode_state = None
        self._episode_action = None
        self._episode_next_state = None
        while len(buffer['states']) < self.batch_size:
            try:
                action_probs = self.policy(state.unsqueeze(0).float())[0]
                if torch.isnan(action_probs).any():
                    raise ValueError(
                        "NaN values detected in action probabilities.")
                action = torch.multinomial(action_probs, num_samples=1).item()
                next_state, reward, done, _ = env.step(action)
                buffer['states'].append(state)
                buffer['actions'].append(action)
                buffer['rewards'].append(reward)
                buffer['next_states'].append(
                    torch.tensor(next_state, dtype=torch.float32))
                buffer['dones'].append(done)
                self._episode_state = state
                self._episode_action = action
                self._episode_next_state = next_state
                self._episode_reward += reward
                self._episode_step += 1
                self._elapsed_steps += 1
                state = torch.tensor(next_state, dtype=torch.float32)
                if done or self._elapsed_steps >= self._max_episode_steps:
                    self._episode_ended_at = time.time()
                    self._episode_done = True
                    state = torch.tensor(env.reset(), dtype=torch.float32)
                if self._elapsed_steps % self.update_interval == 0:
                    self.target_Q_function.load_state_dict(
                        self.Q_function.state_dict())
            except Exception as e:
                logging.error(
                    f"An error occurred during collect_trajectories: {e}")
                state = torch.tensor(env.reset(), dtype=torch.float32)
                continue
        if self._elapsed_steps % self.update_interval == 0:
            self.target_Q_function.load_state_dict(
                self.Q_function.state_dict())
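    # Generalized-advantage-style recursion used below, with self.tau playing the
    # role of the GAE lambda and Q(s, a) standing in for the value estimate:
    #     delta_t = r_t + gamma * Q(s_{t+1}, a_t) * (1 - done_t) - Q(s_t, a_t)
    #     A_t     = delta_t + gamma * tau * (1 - done_t) * A_{t+1}
    # computed backwards over the collected batch.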
    def compute_advantages(self, buffer):
        """
        Compute advantages for each state-action pair in the buffer.
        Args:
            buffer (dict): Buffer containing the collected trajectories.
        Returns:
            advantages (torch.Tensor): Computed advantages.
        """
        states = torch.stack(buffer['states'])
        actions = torch.tensor(buffer['actions'])
        rewards = torch.tensor(buffer['rewards'], dtype=torch.float32)
        next_states = torch.stack(buffer['next_states'])
        dones = torch.tensor(buffer['dones'], dtype=torch.float32)
        advantages = torch.zeros_like(rewards)
        last_advantage = 0
        try:
            # The Q-values are only used as targets here, so keep them out of
            # the autograd graph.
            with torch.no_grad():
                values = self.Q_function(states)
                next_values = self.Q_function(next_states)
            for t in reversed(range(len(states))):
                value = values[t, actions[t]]
                next_value = next_values[t, actions[t]]
                delta = rewards[t] + self.gamma * \
                    next_value * (1 - dones[t]) - value
                advantages[t] = delta + self.gamma * \
                    self.tau * last_advantage * (1 - dones[t])
                last_advantage = advantages[t]
        except Exception as e:
            logging.error(f"An error occurred during compute_advantages: {e}")
        return advantages
    def update_policy(self, buffer):
        """
        Update the policy using the collected trajectories.
        Args:
            buffer (dict): Buffer containing the collected trajectories.
        """
        states = torch.stack(buffer['states'])
        actions = torch.tensor(buffer['actions'])
        advantages = self.compute_advantages(buffer)
        advantages = (advantages - advantages.mean()) / \
            (advantages.std() + 1e-8)
        old_probs = buffer['old_probs']
        old_log_probs = torch.log(
            old_probs.gather(1, actions.unsqueeze(-1)).squeeze(-1) + 1e-10)
        try:
            optimizer = optim.Adam(self.policy.parameters())
            for _ in range(self.n_epochs):
                optimizer.zero_grad()
                action_probs = self.policy(states)
                # Log-probability of the action actually taken at each step.
                new_log_probs = torch.log(
                    action_probs.gather(1, actions.unsqueeze(-1)).squeeze(-1) + 1e-10)
                ratio = torch.exp(new_log_probs - old_log_probs)
                if torch.numel(advantages) > 0:
                    surr1 = ratio * advantages
                    surr2 = torch.clamp(
                        ratio, 1 - self.clip_range, 1 + self.clip_range) * advantages
                    loss = -torch.mean(torch.minimum(surr1, surr2))
                else:
                    loss = (ratio * 0.0).sum()
                    logging.error(
                        "Advantages tensor is empty. Returning zero loss.")
                loss.backward()
                optimizer.step()
        except Exception as e:
            logging.error(f"An error occurred during update_policy: {e}")
    def update_Q_function(self, buffer):
        """
        Update the Q-function using the collected trajectories.
        Args:
            buffer (dict): Buffer containing the collected trajectories.
        """
        states = torch.stack(buffer['states'])
        actions = torch.tensor(buffer['actions'])
        rewards = torch.tensor(buffer['rewards'], dtype=torch.float32)
        next_states = torch.stack(buffer['next_states'])
        dones = torch.tensor(buffer['dones'], dtype=torch.float32)
        # Targets come from the frozen target network and are not backpropagated.
        with torch.no_grad():
            target_values = self.target_Q_function(next_states)
            max_target_values = target_values.max(dim=1)[0]
            target_Q_values = rewards + self.gamma * \
                max_target_values * (1 - dones)
        try:
            optimizer = optim.Adam(self.Q_function.parameters())
            for _ in range(self.n_epochs):
                optimizer.zero_grad()
                current_Q_values = self.Q_function(states)
                predicted_Q_values = current_Q_values.gather(
                    1, actions.unsqueeze(-1)).squeeze(-1)
                loss = nn.MSELoss()(predicted_Q_values, target_Q_values)
                loss.backward()
                optimizer.step()
        except Exception as e:
            logging.error(f"An error occurred during update_Q_function: {e}")
    def train(self, env, max_iterations, target_update_interval):
        """
        Train the agent using the PPO algorithm.
        Args:
            env: Environment to train the agent on.
            max_iterations (int): Maximum number of training iterations.
            target_update_interval (int): Number of iterations between updates of the target Q-function.
        """
        self._max_episode_steps = env.spec.max_episode_steps if hasattr(
            env.spec, 'max_episode_steps') else 1000
        buffer = {'states': [], 'actions': [], 'rewards': [],
                  'next_states': [], 'dones': [], 'old_probs': []}
        for iteration in range(max_iterations):
            try:
                self.collect_trajectories(env, buffer)
                buffer['old_probs'] = self.policy(
                    torch.stack(buffer['states'])).detach()
                self.update_policy(buffer)
                self.update_Q_function(buffer)
                if iteration % target_update_interval == 0:
                    self.target_Q_function.load_state_dict(
                        self.Q_function.state_dict())
                buffer = {'states': [], 'actions': [], 'rewards': [],
                          'next_states': [], 'dones': [], 'old_probs': []}
            except Exception as e:
                logging.error(
                    f"An error occurred during training iteration: {e}")
                continue
    def save(self, path):
        """
        Save the agent's policy and Q-function models.
        Args:
            path (str): Directory in which the models will be saved.
        """
        try:
            Path(path).mkdir(parents=True, exist_ok=True)
            torch.save(self.policy, f"{path}/policy.pt")
            torch.save(self.Q_function, f"{path}/Q_function.pt")
        except Exception as e:
            logging.error(f"An error occurred during saving: {e}")

    def load(self, path):
        """
        Load the agent's policy and Q-function models.
        Args:
            path (str): Directory from which the models are loaded.
        """
        try:
            self.policy = torch.load(f"{path}/policy.pt")
            self.Q_function = torch.load(f"{path}/Q_function.pt")
        except Exception as e:
            logging.error(f"An error occurred during loading: {e}")

    def reset_states(self):
        self._max_episode_steps = None
        self._elapsed_steps = 0
        self._episode_started_at = None
        self._episode_ended_at = None
        self._episode_reward = 0
        self._episode_step = 0
        self._episode_info = None
        self._episode_done = False
        self._episode_state = None
        self._episode_action = None
        self._episode_next_state = None
        # The feed-forward nn.Sequential models carry no recurrent state, so
        # there is nothing to reset on the networks themselves.
def train_agent(market, interval, limit, max_steps, min_balance, profit_target,
                max_iterations, target_update_interval, path):
    env = CryptoTradingEnv(market, interval, limit,
                           max_steps, min_balance, profit_target)
    agent = PPOAgent(state_space_dim=limit + 2, action_space_size=7)
    agent.train(env, max_iterations, target_update_interval)
    agent.save(path)


def evaluate_agent(market, interval, limit, max_steps, min_balance, profit_target, path):
    env = CryptoTradingEnv(market, interval, limit,
                           max_steps, min_balance, profit_target)
    agent = PPOAgent(state_space_dim=limit + 2, action_space_size=7)
    agent.load(path)
    state = env.reset()
    done = False
    while not done:
        action_probs = agent.policy(
            torch.tensor(state, dtype=torch.float32).unsqueeze(0))[0]
        action = torch.multinomial(action_probs, num_samples=1).item()
        state, reward, done, info = env.step(action)
        print(info)


def run_agent(market, interval, limit, max_steps, min_balance, profit_target, path):
    env = CryptoTradingEnv(market, interval, limit,
                           max_steps, min_balance, profit_target)
    agent = PPOAgent(state_space_dim=limit + 2, action_space_size=7)
    agent.load(path)
    state = env.reset()
    done = False
    while not done:
        action_probs = agent.policy(
            torch.tensor(state, dtype=torch.float32).unsqueeze(0))[0]
        action = torch.multinomial(action_probs, num_samples=1).item()
        state, reward, done, info = env.step(action)
        env.render()
        print(info)


def test_agent(market, interval, limit, max_steps, min_balance, profit_target, path):
    env = CryptoTradingEnv(market, interval, limit,
                           max_steps, min_balance, profit_target)
    agent = PPOAgent(state_space_dim=limit + 2, action_space_size=7)
    agent.load(path)
    state = env.reset()
    done = False
    while not done:
        action_probs = agent.policy(
            torch.tensor(state, dtype=torch.float32).unsqueeze(0))[0]
        action = torch.multinomial(action_probs, num_samples=1).item()
        state, reward, done, info = env.step(action)
        print(info)


def optimize_agent(market, interval, limit, max_steps, min_balance, profit_target,
                   max_iterations, target_update_interval, path):
    env = CryptoTradingEnv(market, interval, limit,
                           max_steps, min_balance, profit_target)
    agent = PPOAgent(state_space_dim=limit + 2, action_space_size=7)
    agent.train(env, max_iterations, target_update_interval)
    agent.save(path)


def load_agent(path, limit):
    agent = PPOAgent(state_space_dim=limit + 2, action_space_size=7)
    agent.load(path)
    return agent
def save_agent(agent, path):
    agent.save(path)


def create_dataset(market, interval, limit, max_steps, min_balance, profit_target):
    historical_data = HistoricalData(market, interval, limit)
    dataset = CandlesDataset(historical_data.market_df)
    return dataset


def load_dataset(path):
    dataset = torch.load(path)
    return dataset


def save_dataset(dataset, path):
    torch.save(dataset, path)


def create_dataloader(dataset, batch_size, shuffle, num_workers):
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            shuffle=shuffle, num_workers=num_workers)
    return dataloader


def load_dataloader(path):
    dataloader = torch.load(path)
    return dataloader


def save_dataloader(dataloader, path):
    torch.save(dataloader, path)


def create_model():
    model = nn.Sequential(
        nn.Linear(10, 64),
        nn.ReLU(),
        nn.Linear(64, 1)
    )
    return model


def load_model(path):
    model = torch.load(path)
    return model


def save_model(model, path):
    torch.save(model, path)
if __name__ == "__main__":
    # NOTE: Bitvavo market symbols normally use a dash (e.g. "BTC-EUR"); adjust
    # the market below to one that exists on your account.
    market = "BTC/USDT"
    interval = "1h"
    limit = 100
    max_steps = 1000
    min_balance = 1000
    profit_target = 1000
    max_iterations = 1000
    target_update_interval = 100
    path = "models"  # directory in which policy.pt / Q_function.pt are stored
    train_agent(market, interval, limit, max_steps, min_balance,
                profit_target, max_iterations, target_update_interval, path)
    evaluate_agent(market, interval, limit, max_steps,
                   min_balance, profit_target, path)
    run_agent(market, interval, limit, max_steps,
              min_balance, profit_target, path)
    test_agent(market, interval, limit, max_steps,
               min_balance, profit_target, path)
    optimize_agent(market, interval, limit, max_steps, min_balance,
                   profit_target, max_iterations, target_update_interval, path)
    agent = load_agent(path, limit)
    save_agent(agent, path)
    dataset = create_dataset(market, interval, limit,
                             max_steps, min_balance, profit_target)
    save_dataset(dataset, "dataset.pt")
    dataset = load_dataset("dataset.pt")
    dataloader = create_dataloader(
        dataset, batch_size=32, shuffle=True, num_workers=4)
    save_dataloader(dataloader, "dataloader.pt")
    dataloader = load_dataloader("dataloader.pt")
    model = create_model()
    save_model(model, "simple_model.pt")
    model = load_model("simple_model.pt")