Created
May 11, 2024 20:02
-
-
Save johnayoung/5eee6f87ef78df3ddcef68e953f8560c to your computer and use it in GitHub Desktop.
cross-sectional-momentum-files
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import time | |
from decimal import Decimal | |
from typing import Dict, List, Set, Tuple | |
import numpy as np | |
from pydantic import Field | |
from controllers.directional_trading.ema_crossover_v1 import EMACrossoverController, EMACrossoverControllerConfig | |
from hummingbot.client.config.config_data_types import ClientFieldData | |
from hummingbot.client.hummingbot_application import HummingbotApplication | |
from hummingbot.data_feed.candles_feed.candles_factory import CandlesConfig | |
from hummingbot.smart_components.controllers.data_types.data_types import EqualWeighting | |
from hummingbot.smart_components.executors.rebalance_executor.data_types import ( | |
REBALANCE_EXECUTOR_TYPE, | |
RebalanceExecutorConfig, | |
) | |
from hummingbot.smart_components.models.executor_actions import CreateExecutorAction, ExecutorAction | |
# Default screener universe (base assets). Each entry is joined with the configured
# quote asset to form a trading pair, e.g. "ETH" -> "ETH-USD".
DEFAULT_SCREENER_ASSETS = [
    "USDC",
    "ETH",
    "SOL",
    "DOGE",
    "NEAR",
    "RNDR",
    "ADA",
    "AVAX",
    "XRP",
    "FET",
    "XMR",
]
class CrossMomentumWithTrendOverlayControllerConfig(EMACrossoverControllerConfig):
    """Configuration for the cross-sectional momentum controller with an EMA trend overlay.

    Extends the EMA crossover configuration (which provides the trend signal) with the
    portfolio settings (quote asset, quote weight, minimum order size), the screener
    universe and candle settings, and the cooldown between rebalances.
    """

    controller_name = "cross_momentum_with_trend_overlay_v1"
    connector_name: str = Field(
        default="kraken",
        client_data=ClientFieldData(
            prompt_on_new=True, prompt=lambda mi: "Enter the name of the exchange to trade on (e.g., kraken):"
        ),
    )
    # Asset every screener asset is quoted in (pairs are "<asset>-<quote_asset>").
    quote_asset: str = Field(
        default="USD",
        client_data=ClientFieldData(
            prompt_on_new=True, prompt=lambda mi: "Enter the target quote asset for the portfolio:"
        ),
    )
    # Target share of the portfolio to hold in the quote asset (e.g. 0.05 = 5%).
    quote_weight: float = Field(
        default=0.05,
        client_data=ClientFieldData(
            prompt_on_new=True, prompt=lambda mi: "Enter the target weight of the quote asset in the portfolio:"
        ),
    )
    # Minimum trade notional, in the quote asset, for a rebalance trade to be placed.
    min_order_amount_to_rebalance_quote: Decimal = Field(
        default=Decimal("0.01"),
        client_data=ClientFieldData(
            prompt_on_new=True, prompt=lambda mi: "Enter the minimum order size in quote asset for the exchange:"
        ),
    )
    # Comma-separated list of base assets forming the screener universe.
    screener_assets: str = Field(
        default=",".join(DEFAULT_SCREENER_ASSETS),
        client_data=ClientFieldData(
            prompt_on_new=True, prompt=lambda mi: "Enter the assets to use for the screener universe:"
        ),
    )
    screener_interval: str = Field(
        default="1d",
        client_data=ClientFieldData(
            prompt=lambda mi: "Enter the interval for the screener data (e.g., 1m, 5m, 1h, 1d): ", prompt_on_new=True
        ),
    )
    screener_lookback_period: int = Field(
        default=5,
        gt=0,
        client_data=ClientFieldData(prompt=lambda mi: "Enter the lookback period (e.g. 5): ", prompt_on_new=True),
    )
    # Seconds that must elapse after a rebalance before another one may start.
    cooldown_time: int = Field(
        default=60 * 5,
        gt=0,
        client_data=ClientFieldData(
            is_updatable=True,
            prompt_on_new=False,
            prompt=lambda mi: "Specify the cooldown time in seconds after executing a rebalance (e.g., 300 for 5 minutes):",
        ),
    )

    def update_markets(self, markets: Dict[str, Set[str]]) -> Dict[str, Set[str]]:
        """Register the EMA trend pair and every screener pair on the configured connector.

        Each entry of ``screener_assets`` is stripped of surrounding whitespace. Empty entries
        are skipped, and so is the quote asset itself, which would otherwise produce an invalid
        pair such as ``USD-USD``.

        :param markets: mapping of connector name -> trading pairs; updated in place.
        :return: the same ``markets`` mapping.
        """
        pairs = markets.setdefault(self.connector_name, set())
        pairs.add(self.ema_trading_pair)
        for raw_asset in self.screener_assets.split(","):
            asset = raw_asset.strip()
            if not asset or asset == self.quote_asset:
                continue
            pairs.add(f"{asset}-{self.quote_asset}")
        return markets
class CrossMomentumWithTrendOverlayController(EMACrossoverController):
    """Rebalance a spot portfolio across screened assets, gated by an EMA crossover signal.

    A positive signal creates a rebalance executor targeting equal weights over the selected
    assets. A negative signal creates one that moves the whole portfolio into the quote asset.
    Only one rebalance executor may be active at a time, and ``config.cooldown_time`` seconds
    must pass after the last rebalance executor was created before another is created.
    """

    def __init__(self, config: CrossMomentumWithTrendOverlayControllerConfig, *args, **kwargs):
        self.config = config
        # Build the complete candles list before calling the parent. The parent only installs
        # its EMA candles feed when candles_config is empty, so that feed must be included here
        # explicitly, followed by one feed per screener asset.
        candles_config = [
            CandlesConfig(
                connector=config.connector_name,
                trading_pair=config.ema_trading_pair,
                interval=config.ema_candles_interval,
                max_records=config.max_records,
            )
        ]
        candles_config.extend(
            CandlesConfig(
                connector=config.connector_name,
                trading_pair=f"{asset}-{config.quote_asset}",
                interval=config.screener_interval,
                max_records=config.screener_lookback_period,
            )
            for asset in self._screener_assets()
        )
        self.config.candles_config = candles_config
        super().__init__(config, *args, **kwargs)

    def _screener_assets(self) -> List[str]:
        """Return the configured screener assets, stripped, without blanks or the quote asset."""
        stripped = (asset.strip() for asset in self.config.screener_assets.split(","))
        return [asset for asset in stripped if asset and asset != self.config.quote_asset]

    async def update_processed_data(self):
        """Refresh ``processed_data`` (the EMA trend signal) via the parent controller."""
        await super().update_processed_data()

    def get_current_balances(self) -> Dict[str, Decimal]:
        """Return all balances reported by the configured connector of the running application."""
        hb = HummingbotApplication.main_application()
        return hb.markets[self.config.connector_name].get_all_balances()

    def get_current_assets(self) -> Set[str]:
        """Return every asset that appears in the connector's balance report."""
        return set(self.get_current_balances().keys())

    def get_target_assets(self) -> Set[str]:
        """Return the assets to hold after rebalancing.

        TODO: placeholder screener — replace the random pick with a momentum ranking.
        """
        candidates = self._screener_assets()
        # random.sample raises ValueError when asked for more items than exist, so cap the size.
        return set(random.sample(candidates, min(5, len(candidates))))

    def get_assets_to_close(self, current_assets: Set[str], target_assets: Set[str]) -> Set[str]:
        """Return the currently held assets that are not among the target assets."""
        return set(current_assets) - set(target_assets)

    def calculate_weights_data(self) -> Tuple[Set[str], Set[str], Set[str], Set[str]]:
        """Return (current assets, target assets, assets to close, union of current and target)."""
        current_assets = self.get_current_assets()
        target_assets = self.get_target_assets()
        assets_to_close = self.get_assets_to_close(current_assets, target_assets)
        all_assets = current_assets.union(target_assets)
        return current_assets, target_assets, assets_to_close, all_assets

    def calculate_target_weights(self, to_quote: bool = False) -> Dict[str, float]:
        """Compute per-asset target weights for the next rebalance.

        :param to_quote: when True, assign weight 1.0 to the quote asset and 0.0 to every other asset.
        :return: asset -> weight; otherwise equal weights over ``all_assets`` with the assets to close excluded.
        :raises ValueError: if the computed weights do not sum to 1.
        """
        _, _, assets_to_close, all_assets = self.calculate_weights_data()
        if to_quote:
            target_weights = {asset: 0.0 for asset in all_assets}
            target_weights[self.config.quote_asset] = 1.0
            return target_weights
        weighting_strategy = EqualWeighting()
        weights = weighting_strategy.calculate_weights(assets=all_assets, data={"excluded_assets": assets_to_close})
        total_weight = sum(weights.values())
        # Explicit check rather than `assert`, which is stripped when Python runs with -O.
        if not np.isclose(total_weight, 1.0):
            raise ValueError(f"Sum of weights is not 1: {total_weight}")
        return weights

    def determine_executor_actions(self) -> List[ExecutorAction]:
        """
        Determine actions based on the provided executor handler report.
        """
        actions = []
        actions.extend(self.create_actions_proposal())
        actions.extend(self.stop_actions_proposal())
        return actions

    def create_actions_proposal(self) -> List[ExecutorAction]:
        """
        Propose at most one rebalance executor when the trend signal is non-zero.

        A negative signal targets an all-quote portfolio; a positive one targets the screened assets.
        """
        create_actions = []
        # The signal is absent until update_processed_data has run; treat that as "no action".
        signal = self.processed_data.get("signal", 0)
        if signal != 0 and self.can_create_rebalance_executor(signal):
            to_quote_condition = signal < 0
            create_actions.append(
                CreateExecutorAction(
                    controller_id=self.config.id,
                    executor_config=self.get_rebalance_executor_config(
                        target_weights=self.calculate_target_weights(to_quote=to_quote_condition)
                    ),
                )
            )
        return create_actions

    def stop_actions_proposal(self) -> List[ExecutorAction]:
        """
        Stop actions based on the provided executor handler report (currently none).
        """
        return []

    def can_create_rebalance_executor(self, signal: int) -> bool:
        """
        Return True when no rebalance executor is active and the cooldown has elapsed.

        :param signal: the current trend signal (accepted for interface compatibility; unused).
        """
        rebalance_executors = list(
            self.filter_executors(
                executors=self.executors_info, filter_func=lambda x: x.type == REBALANCE_EXECUTOR_TYPE
            )
        )
        if any(executor.is_active for executor in rebalance_executors):
            return False
        # Finished executors must be included here: once none is active, only they carry the
        # time of the last rebalance. Assumes executors_info retains finished executors — confirm.
        last_timestamp = max((executor.timestamp for executor in rebalance_executors), default=0)
        return time.time() - last_timestamp > self.config.cooldown_time

    def get_rebalance_executor_config(self, target_weights: Dict[str, float]) -> RebalanceExecutorConfig:
        """
        Build a RebalanceExecutorConfig for ``target_weights`` from this controller's settings.
        """
        return RebalanceExecutorConfig(
            timestamp=time.time(),
            connector_name=self.config.connector_name,
            target_weights=target_weights,
            quote_asset=self.config.quote_asset,
            quote_weight=self.config.quote_weight,
            min_order_amount_to_rebalance_quote=self.config.min_order_amount_to_rebalance_quote,
        )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict, List, Set | |
import pandas as pd | |
from pydantic import Field | |
from hummingbot.client.config.config_data_types import ClientFieldData | |
from hummingbot.data_feed.candles_feed.candles_factory import CandlesConfig | |
from hummingbot.smart_components.controllers.controller_base import ControllerBase, ControllerConfigBase | |
class TestModeOptions:
    """Integer codes accepted by ``EMACrossoverControllerConfig.test_mode``.

    Every mode other than STANDARD bypasses the EMA crossover and forces a fixed signal.
    """

    ALWAYS_SELL = -1  # always report signal -1
    STANDARD = 0  # use the EMA crossover signal
    ALWAYS_REBALANCE = 1  # always report signal 1
class EMACrossoverControllerConfig(ControllerConfigBase):
    """Configuration for the EMA crossover directional-trading controller.

    Selects the connector and candle source for the EMA pair, the fast and slow EMA
    periods, and an optional test mode that forces a fixed signal.
    """

    controller_type = "directional_trading"
    controller_name = "ema_crossover_v1"
    connector_name: str = Field(
        default="kraken",
        client_data=ClientFieldData(
            prompt=lambda mi: "Enter the name of the exchange to trade on (e.g., kraken):",
            prompt_on_new=True,
        ),
    )
    # Candle feeds for the controller; when left empty, the controller installs an EMA feed.
    candles_config: List[CandlesConfig] = []
    ema_trading_pair: str = Field(
        default="BTC-USDC",
        client_data=ClientFieldData(
            prompt=lambda mi: "Enter the trading pair for the candles data: ",
            prompt_on_new=True,
        ),
    )
    ema_candles_interval: str = Field(
        default="1m",
        client_data=ClientFieldData(
            prompt_on_new=False,
            prompt=lambda mi: "Enter the candle interval (e.g., 1m, 5m, 1h, 1d): ",
        ),
    )
    ema_fast: int = Field(
        default=5,
        gt=0,
        client_data=ClientFieldData(
            prompt_on_new=True,
            prompt=lambda mi: "Enter the fast EMA period (e.g. 5): ",
        ),
    )
    ema_slow: int = Field(
        default=50,
        gt=0,
        client_data=ClientFieldData(
            prompt_on_new=True,
            prompt=lambda mi: "Enter the slow EMA period (e.g. 50): ",
        ),
    )
    # One of TestModeOptions: -1, 0 or 1 (enforced by the gt/lt bounds below).
    test_mode: int = Field(
        default=TestModeOptions.STANDARD,
        gt=-2,
        lt=2,
        client_data=ClientFieldData(
            prompt=lambda mi: "Enter the test mode (1 = always rebalance, 0 = standard, -1 = always sell):",
            prompt_on_new=True,
        ),
    )

    @property
    def max_records(self) -> int:
        """Number of candles to request: the slow EMA period plus 30 extra records."""
        extra_records = 30
        return self.ema_slow + extra_records

    def update_markets(self, markets: Dict[str, Set[str]]) -> Dict[str, Set[str]]:
        """Add the EMA trading pair to ``markets`` under this connector, creating the entry if needed."""
        markets.setdefault(self.connector_name, set()).add(self.ema_trading_pair)
        return markets
class EMACrossoverController(ControllerBase):
    """Directional controller that derives a trading signal from a fast/slow EMA crossover.

    ``processed_data["signal"]`` is 1 when the fast EMA has just crossed above the slow EMA,
    -1 when it has just crossed below, and 0 otherwise (including when there is not enough
    candle data). ``config.test_mode`` can force a constant 1 or -1.
    """

    def __init__(self, config: EMACrossoverControllerConfig, *args, **kwargs):
        self.config = config
        # Only install the default EMA candles feed when the caller supplied none.
        if len(self.config.candles_config) == 0:
            self.config.candles_config = [
                CandlesConfig(
                    connector=config.connector_name,
                    trading_pair=config.ema_trading_pair,
                    interval=config.ema_candles_interval,
                    max_records=config.max_records,
                )
            ]
        super().__init__(config, *args, **kwargs)

    def get_processed_data(self) -> pd.DataFrame:
        """Fetch the EMA pair's candles and add ``fast_ema`` and ``slow_ema`` columns.

        :return: the candles frame; when fewer than 2 rows are available it is returned
            unchanged, without EMA columns.
        """
        df = self.market_data_provider.get_candles_df(
            self.config.connector_name,
            self.config.ema_trading_pair,
            self.config.ema_candles_interval,
            self.config.max_records,
        )
        if df.empty or len(df) < 2:
            self.logger().warning(
                f"Not enough candles received from get_candles_df to compute EMAs "
                f"(got {len(df)}, need at least 2)."
            )
            return df
        # NOTE(review): `df.ta` is the pandas_ta DataFrame accessor and only exists once
        # pandas_ta has been imported somewhere in the process — confirm that it is.
        df["fast_ema"] = df.ta.ema(length=self.config.ema_fast)
        df["slow_ema"] = df.ta.ema(length=self.config.ema_slow)
        return df

    def get_signal(self) -> int:
        """Return 1 on a bullish crossover, -1 on a bearish crossover, 0 otherwise.

        The test modes ALWAYS_REBALANCE and ALWAYS_SELL short-circuit to 1 and -1.
        """
        if self.config.test_mode == TestModeOptions.ALWAYS_REBALANCE:
            return 1
        if self.config.test_mode == TestModeOptions.ALWAYS_SELL:
            return -1
        df = self.get_processed_data()
        # With insufficient data get_processed_data returns a frame without EMA columns;
        # indexing it would raise IndexError/KeyError, so report "no action" instead.
        if len(df) < 2 or not {"fast_ema", "slow_ema"}.issubset(df.columns):
            return 0
        last = df.iloc[-1]
        prev = df.iloc[-2]
        buy_condition = last["fast_ema"] > last["slow_ema"] and prev["fast_ema"] <= prev["slow_ema"]
        sell_condition = last["fast_ema"] < last["slow_ema"] and prev["fast_ema"] >= prev["slow_ema"]
        if buy_condition:
            return 1  # Trigger rebalance
        if sell_condition:
            return -1  # Trigger position close
        return 0  # No action

    async def update_processed_data(self):
        """
        Update the processed data based on the current state of the strategy.
        """
        self.processed_data["signal"] = self.get_signal()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import logging
from decimal import Decimal
from typing import Dict, List

from hummingbot.core.data_type.common import OrderType, PriceType, TradeType
from hummingbot.logger.logger import HummingbotLogger
from hummingbot.smart_components.executors.executor_base import ExecutorBase
from hummingbot.smart_components.executors.rebalance_executor.data_types import (
    RebalanceExecutorConfig,
    RebalanceExecutorStatus,
)
from hummingbot.smart_components.models.executors import TrackedOrder
from hummingbot.strategy.script_strategy_base import ScriptStrategyBase
class RebalanceAction:
    """A single trade proposed by the rebalance executor.

    ``amount`` is in units of ``asset`` and is signed (negative for sells); ``side`` carries
    the trade direction, and the order is placed for ``abs(amount)``.
    """

    asset: str
    amount: Decimal
    side: TradeType

    def __init__(self, asset: str, amount: Decimal, side: TradeType):
        self.asset, self.amount, self.side = asset, amount, side
class RebalanceExecutor(ExecutorBase):
    """Rebalance a spot portfolio on one connector toward ``config.target_weights``.

    Every non-quote asset is traded against ``<asset>-<quote_asset>`` with MARKET orders.
    A ``quote_weight`` share of the portfolio value is left in the quote asset, and the rest
    is allocated according to the target weights. All sells are placed and awaited before
    any buy is placed.
    """

    _logger = None

    @classmethod
    def logger(cls) -> HummingbotLogger:
        if cls._logger is None:
            # Obtain the logger through the logging registry so it joins the logger hierarchy
            # and its records propagate to handlers configured on ancestor loggers. A directly
            # constructed logger has no parent and no handlers, so its records only reach
            # Python's last-resort stderr handler.
            cls._logger = logging.getLogger(__name__)
        return cls._logger

    @property
    def is_closed(self):
        return self.rebalance_status in [RebalanceExecutorStatus.COMPLETED, RebalanceExecutorStatus.FAILED]

    def __init__(self, strategy: ScriptStrategyBase, config: RebalanceExecutorConfig, update_interval: float = 1.0):
        super().__init__(
            strategy=strategy, connectors=[config.connector_name], config=config, update_interval=update_interval
        )
        self.config = config
        # Convert via str() so a float weight such as 0.2 becomes Decimal("0.2") rather than
        # its exact binary expansion Decimal(0.2) == 0.200000000000000011102...
        self.target_weights = {asset: Decimal(str(weight)) for asset, weight in config.target_weights.items()}
        self.quote_asset = config.quote_asset
        self.quote_weight = config.quote_weight
        self.min_order_amount_to_rebalance_quote = config.min_order_amount_to_rebalance_quote
        self.rebalance_status = RebalanceExecutorStatus.INITIALIZING
        self.tracked_orders: Dict[str, TrackedOrder] = {}

    @property
    def current_balances(self) -> Dict[str, Decimal]:
        """Return all balances reported by the connector, as Decimals (queried on every access)."""
        cb = self.connectors[self.config.connector_name].get_all_balances()
        return {asset: Decimal(amount) for asset, amount in cb.items()}

    def validate_sufficient_balance(self):
        # TODO
        pass

    def get_net_pnl_quote(self) -> Decimal:
        """
        TODO: Returns the net profit or loss in quote currency.
        """
        return Decimal(0)

    def get_net_pnl_pct(self) -> Decimal:
        """
        TODO: Returns the net profit or loss in percentage.
        """
        return Decimal(0)

    def get_cum_fees_quote(self) -> Decimal:
        """
        Returns the cumulative fees in quote currency (not tracked yet; always 0).
        """
        return Decimal(0)

    def get_trading_pair(self, asset: str) -> str:
        return f"{asset}-{self.quote_asset}"

    def _get_mid_price(self, asset: str) -> Decimal:
        """Return the mid price of ``asset`` in the quote asset.

        :raises ValueError: if the price is NaN or not positive.
        """
        trading_pair = self.get_trading_pair(asset)
        price = self.get_price(
            connector_name=self.config.connector_name,
            trading_pair=trading_pair,
            price_type=PriceType.MidPrice,
        )
        # The connector may report NaN (or 0) when no price is available. Check NaN first,
        # because ordering comparisons against a NaN Decimal raise InvalidOperation.
        if price.is_nan() or price <= 0:
            raise ValueError(f"No valid mid price for {trading_pair}: {price}")
        return price

    def get_asset_value_in_quote(self, asset: str, amount: Decimal) -> Decimal:
        """Value ``amount`` units of ``asset`` at the current mid price, in the quote asset."""
        return amount * self._get_mid_price(asset)

    def calculate_total_portfolio_value(self):
        """Sum, in the quote asset, the quote balance plus the value of held target assets."""
        total_value = Decimal(0)
        for asset, amount in self.current_balances.items():
            if asset == self.quote_asset:
                total_value += amount
            elif asset in self.target_weights and amount != 0:
                # Zero balances add nothing, so they are not priced.
                total_value += self.get_asset_value_in_quote(asset, amount)
        return total_value

    def calculate_total_rebalance_value(self):
        """Return the part of the portfolio value to distribute by target weight (excludes quote_weight)."""
        total_value = self.calculate_total_portfolio_value()
        rebalance_weight = Decimal(1) - Decimal(str(self.quote_weight))
        return total_value * rebalance_weight

    def calculate_rebalance_actions(self) -> List[RebalanceAction]:
        """Compute the trades that move each non-quote asset to its target value.

        Trades whose notional is below ``min_order_amount_to_rebalance_quote`` are skipped,
        and sell sizes are capped at the currently held balance.
        """
        total_rebalance_value = self.calculate_total_rebalance_value()
        # Snapshot balances once instead of re-querying the connector for every asset.
        balances = self.current_balances
        trade_actions: List[RebalanceAction] = []
        for asset, weight in self.target_weights.items():
            if asset == self.quote_asset:
                continue  # The other trades settle in the quote asset; it is never traded directly.
            held = balances.get(asset, Decimal(0))
            if held == 0 and weight == 0:
                continue  # Nothing held and nothing wanted: no trade, so no need to price it.
            price = self._get_mid_price(asset)
            amount_in_quote = total_rebalance_value * weight - held * price
            amount = amount_in_quote / price
            # Never sell more than is held: rounding in the value -> amount round trip can
            # overshoot the balance slightly when the target value is zero.
            amount = max(amount, -held)
            if amount == 0 or abs(amount_in_quote) < self.min_order_amount_to_rebalance_quote:
                continue
            side = TradeType.BUY if amount > 0 else TradeType.SELL
            trade_actions.append(RebalanceAction(asset=asset, amount=amount, side=side))
        return trade_actions

    async def control_task(self):
        """Perform the rebalance once: place and await all sells, then all buys.

        Any exception marks the executor FAILED; later calls return immediately once closed.
        """
        try:
            if self.is_closed:
                return
            self.rebalance_status = RebalanceExecutorStatus.SELLING
            actions = self.calculate_rebalance_actions()
            sell_actions = [action for action in actions if action.side == TradeType.SELL]
            buy_actions = [action for action in actions if action.side == TradeType.BUY]
            # Sells complete before buys start — presumably so their proceeds can fund the buys.
            await asyncio.gather(*(self.place_order_and_wait(action) for action in sell_actions))
            self.rebalance_status = RebalanceExecutorStatus.BUYING
            # NOTE: buy sizes were computed from pre-trade balances and mid prices.
            await asyncio.gather(*(self.place_order_and_wait(action) for action in buy_actions))
            self.rebalance_status = RebalanceExecutorStatus.COMPLETED
        except Exception as e:
            self.logger().error(f"Error in rebalance executor: {str(e)}", exc_info=True)
            self.rebalance_status = RebalanceExecutorStatus.FAILED

    async def place_order_and_wait(self, action: RebalanceAction):
        """Place a MARKET order for ``abs(action.amount)`` and wait until it is done.

        :raises Exception: re-raises any placement or fill failure after marking the executor FAILED.
        """
        asset = action.asset
        amount = abs(action.amount)
        side = action.side
        try:
            order_id = self.place_order(
                connector_name=self.config.connector_name,
                trading_pair=self.get_trading_pair(asset),
                order_type=OrderType.MARKET,
                side=side,
                amount=amount,
            )
            if order_id:
                self.update_tracked_order(order_id)
                await self.wait_for_order_completion(order_id)
        except Exception as e:
            self.logger().error(f"Error placing order for {asset}: {str(e)}")
            self.rebalance_status = RebalanceExecutorStatus.FAILED
            raise

    async def wait_for_order_completion(self, order_id: str):
        """Poll once per second until ``order_id`` is done.

        NOTE(review): there is no timeout; if the order is never picked up as in-flight this
        waits forever — consider adding one.
        """
        while not self.is_order_complete(order_id):
            await asyncio.sleep(1)  # Check order status every second

    def is_order_complete(self, order_id: str) -> bool:
        """Return True once the tracked order is done.

        :raises Exception: if the order finished in a failed state.
        """
        tracked_order = self.update_tracked_order(order_id)
        if tracked_order and tracked_order.is_done:
            if tracked_order.order and tracked_order.order.is_failure:
                raise Exception(f"Order {order_id} ({tracked_order.order.trading_pair}) failed.")
            return True
        return False

    def get_custom_info(self) -> Dict:
        return {
            "rebalance_status": self.rebalance_status,
            "current_balances": self.current_balances,
            "target_weights": self.target_weights,
            "quote_asset": self.quote_asset,
            "quote_weight": self.quote_weight,
            "min_order_amount_to_rebalance_quote": self.min_order_amount_to_rebalance_quote,
        }

    def update_tracked_order(self, order_id: str):
        """Return the TrackedOrder for ``order_id``, creating it and attaching the in-flight order when available."""
        tracked_order = self.tracked_orders.get(order_id)
        if tracked_order is None:
            tracked_order = TrackedOrder(order_id=order_id)
            self.tracked_orders[order_id] = tracked_order
        if tracked_order.order is None:
            in_flight_order = self.get_in_flight_order(connector_name=self.config.connector_name, order_id=order_id)
            if in_flight_order:
                tracked_order.order = in_flight_order
        return tracked_order
Hello, I'm very excited to see this strategy — it's really great. I have a question about cross_momentum_with_trend_overlay_v1.py. These two imports raise errors: `from hummingbot.smart_components.controllers.data_types.data_types import EqualWeighting` and `from hummingbot.smart_components.executors.rebalance_executor.data_types import (REBALANCE_EXECUTOR_TYPE, RebalanceExecutorConfig)`. Could you provide these two `data_types` modules? Thank you.
I saw these files in your GitHub repository. Thank you!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hello, I'm very excited to see this strategy — it's really great. I have a question about cross_momentum_with_trend_overlay_v1.py:
from hummingbot.smart_components.controllers.data_types.data_types import EqualWeighting
from hummingbot.smart_components.executors.rebalance_executor.data_types import (
REBALANCE_EXECUTOR_TYPE,
RebalanceExecutorConfig,
)
These two imports raise errors. Could you provide these two `data_types` modules? Thank you.