Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save johnayoung/5eee6f87ef78df3ddcef68e953f8560c to your computer and use it in GitHub Desktop.
Save johnayoung/5eee6f87ef78df3ddcef68e953f8560c to your computer and use it in GitHub Desktop.
cross-sectional-momentum-files
import random
import time
from decimal import Decimal
from typing import Dict, List, Set, Tuple
import numpy as np
from pydantic import Field
from controllers.directional_trading.ema_crossover_v1 import EMACrossoverController, EMACrossoverControllerConfig
from hummingbot.client.config.config_data_types import ClientFieldData
from hummingbot.client.hummingbot_application import HummingbotApplication
from hummingbot.data_feed.candles_feed.candles_factory import CandlesConfig
from hummingbot.smart_components.controllers.data_types.data_types import EqualWeighting
from hummingbot.smart_components.executors.rebalance_executor.data_types import (
REBALANCE_EXECUTOR_TYPE,
RebalanceExecutorConfig,
)
from hummingbot.smart_components.models.executor_actions import CreateExecutorAction, ExecutorAction
# Default screener universe (base assets only). Each entry is combined with the
# configured quote asset into an "<asset>-<quote>" trading pair.
DEFAULT_SCREENER_ASSETS = [
    "USDC",
    "ETH",
    "SOL",
    "DOGE",
    "NEAR",
    "RNDR",
    "ADA",
    "AVAX",
    "XRP",
    "FET",
    "XMR",
]
class CrossMomentumWithTrendOverlayControllerConfig(EMACrossoverControllerConfig):
    """
    Config for the cross-sectional momentum controller with an EMA trend overlay.

    Extends the EMA crossover config with the portfolio quote asset and its
    target weight, the minimum rebalance order size (in the quote asset), the
    screener universe (a comma-separated list of base assets) with its candle
    interval and lookback, and the cooldown between rebalances.
    """

    controller_name = "cross_momentum_with_trend_overlay_v1"
    connector_name: str = Field(
        default="kraken",
        client_data=ClientFieldData(
            prompt_on_new=True, prompt=lambda mi: "Enter the name of the exchange to trade on (e.g., kraken):"
        ),
    )
    quote_asset: str = Field(
        default="USD",
        client_data=ClientFieldData(
            prompt_on_new=True, prompt=lambda mi: "Enter the target quote asset for the portfolio:"
        ),
    )
    quote_weight: float = Field(
        default=0.05,
        client_data=ClientFieldData(
            prompt_on_new=True, prompt=lambda mi: "Enter the target weight of the quote asset in the portfolio:"
        ),
    )
    min_order_amount_to_rebalance_quote: Decimal = Field(
        default=Decimal("0.01"),
        client_data=ClientFieldData(
            prompt_on_new=True, prompt=lambda mi: "Enter the minimum order size in quote asset for the exchange:"
        ),
    )
    # Comma-separated base assets, e.g. "ETH,SOL,DOGE".
    screener_assets: str = Field(
        default=",".join(DEFAULT_SCREENER_ASSETS),
        client_data=ClientFieldData(
            prompt_on_new=True, prompt=lambda mi: "Enter the assets to use for the screener universe:"
        ),
    )
    screener_interval: str = Field(
        default="1d",
        client_data=ClientFieldData(
            prompt=lambda mi: "Enter the interval for the screener data (e.g., 1m, 5m, 1h, 1d): ", prompt_on_new=True
        ),
    )
    screener_lookback_period: int = Field(
        default=5,
        gt=0,
        client_data=ClientFieldData(prompt=lambda mi: "Enter the lookback period (e.g. 5): ", prompt_on_new=True),
    )
    cooldown_time: int = Field(
        default=60 * 5,
        gt=0,
        client_data=ClientFieldData(
            is_updatable=True,
            prompt_on_new=False,
            prompt=lambda mi: "Specify the cooldown time in seconds after executing a rebalance (e.g., 300 for 5 minutes):",
        ),
    )

    def update_markets(self, markets: Dict[str, Set[str]]) -> Dict[str, Set[str]]:
        """
        Register every trading pair this controller needs on ``connector_name``.

        Adds ``ema_trading_pair`` plus one ``"<asset>-<quote_asset>"`` pair per
        screener asset. Screener entries are stripped of surrounding whitespace
        and empty entries are skipped, so user input such as ``"ETH, SOL,"``
        does not produce malformed pairs like ``" SOL-USD"`` or ``"-USD"``.

        :param markets: connector name -> set of trading pairs; mutated in place.
        :return: the same ``markets`` mapping.
        """
        pairs = markets.setdefault(self.connector_name, set())
        pairs.add(self.ema_trading_pair)
        for raw_asset in self.screener_assets.split(","):
            asset = raw_asset.strip()
            if asset:
                pairs.add(f"{asset}-{self.quote_asset}")
        return markets
class CrossMomentumWithTrendOverlayController(EMACrossoverController):
    """
    Cross-sectional momentum controller gated by the parent's EMA crossover signal.

    On a positive signal it proposes a rebalance of the current and target assets
    using equal weights, with assets outside the target set excluded. On a
    negative signal it proposes moving the whole portfolio into the quote asset.

    At most one rebalance executor is active at a time. A new one is not created
    until ``cooldown_time`` seconds after the most recent one was created.
    """

    # Number of assets drawn by the placeholder target selection (see get_target_assets).
    TARGET_ASSET_COUNT = 5

    def __init__(self, config: CrossMomentumWithTrendOverlayControllerConfig, *args, **kwargs):
        self.config = config
        # One candles feed per screener asset.
        screener_candles = [
            CandlesConfig(
                connector=config.connector_name,
                trading_pair=f"{asset}-{config.quote_asset}",
                interval=config.screener_interval,
                max_records=config.screener_lookback_period,
            )
            for asset in self._parse_screener_assets(config.screener_assets)
        ]
        # The parent only adds its EMA candles feed when candles_config is empty,
        # so that feed must be included here explicitly; the parent's
        # get_processed_data reads it to compute the EMA signal.
        ema_candles = CandlesConfig(
            connector=config.connector_name,
            trading_pair=config.ema_trading_pair,
            interval=config.ema_candles_interval,
            max_records=config.max_records,
        )
        self.config.candles_config = screener_candles + [ema_candles]
        super().__init__(config, *args, **kwargs)

    @staticmethod
    def _parse_screener_assets(screener_assets: str) -> List[str]:
        """Split the comma-separated screener string, stripping whitespace and dropping empty entries."""
        return [asset.strip() for asset in screener_assets.split(",") if asset.strip()]

    async def update_processed_data(self):
        """Refresh processed data; currently only the parent's EMA signal is computed."""
        await super().update_processed_data()

    def get_current_balances(self) -> Dict[str, float]:
        """Return all balances reported by the configured connector of the running Hummingbot app."""
        hb = HummingbotApplication.main_application()
        return hb.markets[self.config.connector_name].get_all_balances()

    def get_current_assets(self) -> Set[str]:
        """Return the asset names present in the connector's balance report."""
        return set(self.get_current_balances().keys())

    def get_target_assets(self) -> Set[str]:
        """
        Return the assets the portfolio should hold.

        TODO: placeholder that draws up to ``TARGET_ASSET_COUNT`` random screener
        assets; replace with the cross-sectional momentum ranking. The sample size
        is capped at the universe size so small universes do not raise ValueError.
        """
        universe = self._parse_screener_assets(self.config.screener_assets)
        return set(random.sample(universe, min(self.TARGET_ASSET_COUNT, len(universe))))

    def get_assets_to_close(self, current_assets: Set[str], target_assets: Set[str]) -> Set[str]:
        """Return the currently held assets that are not in the target set."""
        return current_assets - target_assets

    def calculate_weights_data(self) -> Tuple[Set[str], Set[str], Set[str], Set[str]]:
        """
        Gather the asset sets used for weighting.

        :return: (current_assets, target_assets, assets_to_close, all_assets), where
            all_assets is the union of current and target assets.
        """
        current_assets = self.get_current_assets()
        target_assets = self.get_target_assets()
        assets_to_close = self.get_assets_to_close(current_assets, target_assets)
        all_assets = current_assets | target_assets
        return current_assets, target_assets, assets_to_close, all_assets

    def calculate_target_weights(self, to_quote: bool = False) -> Dict[str, float]:
        """
        Compute the target weight of every current and target asset.

        :param to_quote: when True, every asset gets weight 0.0 and the quote asset 1.0.
        :return: asset -> weight; in the equal-weighting case the weights sum to 1.
        """
        _, _, assets_to_close, all_assets = self.calculate_weights_data()
        if to_quote:
            target_weights = {asset: 0.0 for asset in all_assets}
            target_weights[self.config.quote_asset] = 1.0
            return target_weights
        weighting_strategy = EqualWeighting()
        weights = weighting_strategy.calculate_weights(assets=all_assets, data={"excluded_assets": assets_to_close})
        # Internal invariant of the weighting strategy.
        assert np.isclose(sum(weights.values()), 1.0), f"Sum of weights is not 1: {sum(weights.values())}"
        return weights

    def determine_executor_actions(self) -> List[ExecutorAction]:
        """Return the create actions followed by the stop actions for this tick."""
        actions = []
        actions.extend(self.create_actions_proposal())
        actions.extend(self.stop_actions_proposal())
        return actions

    def create_actions_proposal(self) -> List[ExecutorAction]:
        """
        Propose a rebalance executor when the EMA signal is non-zero.

        A positive signal targets the equal-weight portfolio; a negative signal
        targets a portfolio held entirely in the quote asset.
        """
        create_actions = []
        # The signal is absent until update_processed_data has run once.
        signal = self.processed_data.get("signal", 0)
        if signal != 0 and self.can_create_rebalance_executor(signal):
            to_quote_condition = signal < 0
            create_actions.append(
                CreateExecutorAction(
                    controller_id=self.config.id,
                    executor_config=self.get_rebalance_executor_config(
                        target_weights=self.calculate_target_weights(to_quote=to_quote_condition)
                    ),
                )
            )
        return create_actions

    def stop_actions_proposal(self) -> List[ExecutorAction]:
        """Propose stop actions; this controller currently never stops executors."""
        stop_actions = []
        return stop_actions

    def can_create_rebalance_executor(self, signal: int) -> bool:
        """
        Check whether a new rebalance executor may be created.

        Creation is blocked while any rebalance executor is active. It is also
        blocked until ``cooldown_time`` seconds have passed since the most recent
        rebalance executor (active or finished) was created.

        :param signal: the current EMA signal (not used by this check).
        """
        rebalance_executors = self.filter_executors(
            executors=self.executors_info, filter_func=lambda x: x.type == REBALANCE_EXECUTOR_TYPE
        )
        if any(executor.is_active for executor in rebalance_executors):
            return False
        # Finished executors must be included here: if only active ones were
        # considered, the cooldown could never apply, because any active
        # executor already blocks creation above.
        last_timestamp = max((executor.timestamp for executor in rebalance_executors), default=0)
        return time.time() - last_timestamp > self.config.cooldown_time

    def get_rebalance_executor_config(self, target_weights: Dict[str, float]) -> RebalanceExecutorConfig:
        """Build a rebalance executor config for ``target_weights`` using this controller's settings."""
        return RebalanceExecutorConfig(
            timestamp=time.time(),
            connector_name=self.config.connector_name,
            target_weights=target_weights,
            quote_asset=self.config.quote_asset,
            quote_weight=self.config.quote_weight,
            min_order_amount_to_rebalance_quote=self.config.min_order_amount_to_rebalance_quote,
        )
from typing import Dict, List, Set
import pandas as pd
from pydantic import Field
from hummingbot.client.config.config_data_types import ClientFieldData
from hummingbot.data_feed.candles_feed.candles_factory import CandlesConfig
from hummingbot.smart_components.controllers.controller_base import ControllerBase, ControllerConfigBase
class TestModeOptions:
    """
    Integer codes accepted by ``EMACrossoverControllerConfig.test_mode``.

    In the two non-standard modes, ``EMACrossoverController.get_signal``
    returns a fixed signal instead of evaluating the EMA crossover.
    """

    STANDARD = 0  # evaluate the EMA crossover normally
    ALWAYS_REBALANCE = +1  # get_signal always returns 1
    ALWAYS_SELL = -1  # get_signal always returns -1
class EMACrossoverControllerConfig(ControllerConfigBase):
    """
    Config for the EMA crossover directional controller.

    Candles for ``ema_trading_pair`` at ``ema_candles_interval`` are used to
    compute a fast and a slow EMA. ``test_mode`` can force a fixed signal;
    see ``TestModeOptions`` for the codes.
    """

    controller_type = "directional_trading"
    controller_name = "ema_crossover_v1"

    connector_name: str = Field(
        default="kraken",
        client_data=ClientFieldData(
            prompt=lambda mi: "Enter the name of the exchange to trade on (e.g., kraken):",
            prompt_on_new=True,
        ),
    )
    # Left empty by default; EMACrossoverController fills in a single EMA feed when empty.
    candles_config: List[CandlesConfig] = []
    ema_trading_pair: str = Field(
        default="BTC-USDC",
        client_data=ClientFieldData(
            prompt=lambda mi: "Enter the trading pair for the candles data: ",
            prompt_on_new=True,
        ),
    )
    ema_candles_interval: str = Field(
        default="1m",
        client_data=ClientFieldData(
            prompt_on_new=False,
            prompt=lambda mi: "Enter the candle interval (e.g., 1m, 5m, 1h, 1d): ",
        ),
    )
    ema_fast: int = Field(
        default=5,
        gt=0,
        client_data=ClientFieldData(
            prompt_on_new=True,
            prompt=lambda mi: "Enter the fast EMA period (e.g. 5): ",
        ),
    )
    ema_slow: int = Field(
        default=50,
        gt=0,
        client_data=ClientFieldData(
            prompt_on_new=True,
            prompt=lambda mi: "Enter the slow EMA period (e.g. 50): ",
        ),
    )
    # Accepted values: -1, 0, 1 (enforced by gt=-2 / lt=2).
    test_mode: int = Field(
        default=TestModeOptions.STANDARD,
        gt=-2,
        lt=2,
        client_data=ClientFieldData(
            prompt=lambda mi: "Enter the test mode (1 = always rebalance, 0 = standard, -1 = always sell):",
            prompt_on_new=True,
        ),
    )

    @property
    def max_records(self) -> int:
        """Number of candles to request: the slow EMA period plus 30 extra rows."""
        # NOTE(review): the 30 extra candles look like an EMA warm-up margin — confirm.
        extra_candles = 30
        return self.ema_slow + extra_candles

    def update_markets(self, markets: Dict[str, Set[str]]) -> Dict[str, Set[str]]:
        """
        Register ``ema_trading_pair`` under ``connector_name``.

        :param markets: connector name -> set of trading pairs; mutated in place.
        :return: the same ``markets`` mapping.
        """
        markets.setdefault(self.connector_name, set()).add(self.ema_trading_pair)
        return markets
class EMACrossoverController(ControllerBase):
    """
    Directional controller that signals on EMA crossovers of a single trading pair.

    ``processed_data["signal"]`` is set as follows:
    - 1 when the fast EMA crosses above the slow EMA on the latest candle;
    - -1 when the fast EMA crosses below the slow EMA on the latest candle;
    - 0 otherwise, including when there is not enough data to evaluate a crossover.

    ``config.test_mode`` can force a fixed 1 or -1 instead.
    """

    def __init__(self, config: EMACrossoverControllerConfig, *args, **kwargs):
        self.config = config
        # Default to one candles feed for the EMA pair unless feeds were already configured.
        if len(self.config.candles_config) == 0:
            self.config.candles_config = [
                CandlesConfig(
                    connector=config.connector_name,
                    trading_pair=config.ema_trading_pair,
                    interval=config.ema_candles_interval,
                    max_records=config.max_records,
                )
            ]
        super().__init__(config, *args, **kwargs)

    def get_processed_data(self) -> pd.DataFrame:
        """
        Fetch candles for the EMA pair and add ``fast_ema`` / ``slow_ema`` columns.

        With fewer than two candles, a warning is logged and the frame is returned
        without the EMA columns.
        """
        df = self.market_data_provider.get_candles_df(
            self.config.connector_name,
            self.config.ema_trading_pair,
            self.config.ema_candles_interval,
            self.config.max_records,
        )
        if df.empty or len(df) < 2:
            self.logger().warning(
                f"Not enough candles received from get_candles_df (got {len(df)}, need at least 2)."
            )
            return df
        # NOTE(review): `df.ta` is the pandas_ta DataFrame accessor; this module does not
        # import pandas_ta itself, so confirm it is registered elsewhere before this runs.
        df["fast_ema"] = df.ta.ema(length=self.config.ema_fast)
        df["slow_ema"] = df.ta.ema(length=self.config.ema_slow)
        return df

    def get_signal(self) -> int:
        """
        Return 1 on a bullish crossover, -1 on a bearish crossover, 0 otherwise.

        Test modes short-circuit to a fixed signal. Returns 0 when there are fewer
        than two candles, or when any EMA value needed for the comparison is
        missing (e.g. not enough history yet for the slow EMA).
        """
        if self.config.test_mode == TestModeOptions.ALWAYS_REBALANCE:
            return 1
        elif self.config.test_mode == TestModeOptions.ALWAYS_SELL:
            return -1
        df = self.get_processed_data()
        # get_processed_data returns the frame without EMA columns when it has
        # fewer than two rows, so there is no crossover to evaluate.
        if len(df) < 2 or "fast_ema" not in df.columns or "slow_ema" not in df.columns:
            return 0
        last = df.iloc[-1]
        prev = df.iloc[-2]
        ema_values = (last["fast_ema"], last["slow_ema"], prev["fast_ema"], prev["slow_ema"])
        # Missing values (NaN, or None if the indicator returned nothing) cannot be compared.
        if any(pd.isna(value) for value in ema_values):
            return 0
        buy_condition = last["fast_ema"] > last["slow_ema"] and prev["fast_ema"] <= prev["slow_ema"]
        sell_condition = last["fast_ema"] < last["slow_ema"] and prev["fast_ema"] >= prev["slow_ema"]
        if buy_condition:
            return 1  # Trigger rebalance
        elif sell_condition:
            return -1  # Trigger position close
        return 0  # No action

    async def update_processed_data(self):
        """Store the current signal in ``processed_data["signal"]``."""
        signal = self.get_signal()
        self.processed_data["signal"] = signal
import asyncio
from decimal import Decimal
from typing import Dict, List
from hummingbot.core.data_type.common import OrderType, PriceType, TradeType
from hummingbot.logger.logger import HummingbotLogger
from hummingbot.smart_components.executors.executor_base import ExecutorBase
from hummingbot.smart_components.executors.rebalance_executor.data_types import (
RebalanceExecutorConfig,
RebalanceExecutorStatus,
)
from hummingbot.smart_components.models.executors import TrackedOrder
from hummingbot.strategy.script_strategy_base import ScriptStrategyBase
class RebalanceAction:
    """
    One market order the rebalance executor intends to place.

    ``amount`` is in units of ``asset`` and is negative for sells. The executor
    places the order for its absolute value on the given ``side``.
    """

    asset: str
    amount: Decimal
    side: TradeType

    def __init__(self, asset: str, amount: Decimal, side: TradeType):
        self.asset, self.amount, self.side = asset, amount, side
class RebalanceExecutor(ExecutorBase):
    """
    Executor that rebalances balances on a single connector toward target weights.

    Assets are valued in ``quote_asset`` at the mid price. One market order is
    computed per asset whose value differs from its target by at least
    ``min_order_amount_to_rebalance_quote``. All sells are placed and awaited
    first, then all buys. The executor ends in ``COMPLETED`` or ``FAILED``.
    """

    _logger = None

    @classmethod
    def logger(cls) -> HummingbotLogger:
        """Return the class-level logger, creating it on first use."""
        if cls._logger is None:
            cls._logger = HummingbotLogger(__name__)
        return cls._logger

    @property
    def is_closed(self):
        """True once the rebalance has reached a terminal status (completed or failed)."""
        return self.rebalance_status in [RebalanceExecutorStatus.COMPLETED, RebalanceExecutorStatus.FAILED]

    def __init__(self, strategy: ScriptStrategyBase, config: RebalanceExecutorConfig, update_interval: float = 1.0):
        super().__init__(
            strategy=strategy, connectors=[config.connector_name], config=config, update_interval=update_interval
        )
        self.config = config
        # Convert through str() so float weights such as 0.2 become Decimal("0.2")
        # rather than the exact, slightly-off binary value of the float.
        self.target_weights = {asset: self._to_decimal(weight) for asset, weight in config.target_weights.items()}
        self.quote_asset = config.quote_asset
        self.quote_weight = config.quote_weight
        self.min_order_amount_to_rebalance_quote = config.min_order_amount_to_rebalance_quote
        self.rebalance_status = RebalanceExecutorStatus.INITIALIZING
        self.tracked_orders: Dict[str, TrackedOrder] = {}

    @staticmethod
    def _to_decimal(value) -> Decimal:
        """Return ``value`` as a Decimal; non-Decimal inputs are converted via ``str``."""
        return value if isinstance(value, Decimal) else Decimal(str(value))

    @property
    def current_balances(self) -> Dict[str, Decimal]:
        """Fresh snapshot of all connector balances as Decimals (queries the connector on each access)."""
        cb = self.connectors[self.config.connector_name].get_all_balances()
        return {asset: Decimal(amount) for asset, amount in cb.items()}

    def validate_sufficient_balance(self):
        # TODO: not implemented; performs no check.
        pass

    def get_net_pnl_quote(self) -> Decimal:
        """
        TODO: Returns the net profit or loss in quote currency (currently always 0).
        """
        return Decimal(0)

    def get_net_pnl_pct(self) -> Decimal:
        """
        TODO: Returns the net profit or loss in percentage (currently always 0).
        """
        return Decimal(0)

    def get_cum_fees_quote(self) -> Decimal:
        """
        TODO: Returns the cumulative fees in quote currency; fees are not tracked yet, so this is always 0.
        """
        return Decimal(0)

    def get_trading_pair(self, asset: str) -> str:
        """Return the ``"<asset>-<quote_asset>"`` trading pair for ``asset``."""
        return f"{asset}-{self.quote_asset}"

    def _get_mid_price(self, asset: str) -> Decimal:
        """Return the mid price of ``asset`` in ``quote_asset`` on the configured connector."""
        return self.get_price(
            connector_name=self.config.connector_name,
            trading_pair=self.get_trading_pair(asset),
            price_type=PriceType.MidPrice,
        )

    def get_asset_value_in_quote(self, asset: str, amount: Decimal) -> Decimal:
        """Return the value of ``amount`` units of ``asset`` in the quote asset, at the mid price."""
        return amount * self._get_mid_price(asset)

    def calculate_total_portfolio_value(self):
        """
        Return the quote value of the quote balance plus every balance that has a target weight.

        Balances of assets without a target weight are ignored.
        """
        total_value = Decimal(0)
        for asset, amount in self.current_balances.items():
            if asset == self.quote_asset:
                total_value += amount
            elif asset in self.target_weights:
                total_value += self.get_asset_value_in_quote(asset, amount)
        return total_value

    def calculate_total_rebalance_value(self):
        """Return the portfolio value available to non-quote assets: total * (1 - quote_weight)."""
        rebalance_weight = Decimal(1) - self._to_decimal(self.quote_weight)
        return self.calculate_total_portfolio_value() * rebalance_weight

    def calculate_rebalance_actions(self) -> List[RebalanceAction]:
        """
        Compute one buy/sell action per non-quote asset whose value deviates from its target.

        An asset's target value is ``total_rebalance_value * weight``. The order
        amount is the quote-value gap divided by the mid price. Gaps smaller than
        ``min_order_amount_to_rebalance_quote`` are skipped.
        """
        total_rebalance_value = self.calculate_total_rebalance_value()
        # Snapshot balances once: the property queries the connector on every access.
        balances = self.current_balances
        trade_actions = []
        for asset, weight in self.target_weights.items():
            if asset == self.quote_asset:
                continue  # The quote asset absorbs the other trades; it is never traded directly.
            asset_price = self._get_mid_price(asset)
            target_value = total_rebalance_value * weight
            current_value = balances.get(asset, Decimal(0)) * asset_price
            amount_in_quote = target_value - current_value
            amount = amount_in_quote / asset_price
            is_non_zero_amount = abs(amount) > 0
            is_above_min_order_amount = abs(amount_in_quote) >= self.min_order_amount_to_rebalance_quote
            if is_non_zero_amount and is_above_min_order_amount:
                side = TradeType.BUY if amount > 0 else TradeType.SELL
                trade_actions.append(RebalanceAction(asset=asset, amount=amount, side=side))
        return trade_actions

    async def control_task(self):
        """
        Run the rebalance once: place and await all sells, then all buys.

        Any exception is logged and marks the executor FAILED.
        """
        try:
            if self.is_closed:
                return
            self.rebalance_status = RebalanceExecutorStatus.SELLING
            actions = self.calculate_rebalance_actions()
            sell_actions = [action for action in actions if action.side == TradeType.SELL]
            buy_actions = [action for action in actions if action.side == TradeType.BUY]
            # Sells go first so their proceeds are in the quote balance before buying.
            await asyncio.gather(*(self.place_order_and_wait(action) for action in sell_actions))
            self.rebalance_status = RebalanceExecutorStatus.BUYING
            await asyncio.gather(*(self.place_order_and_wait(action) for action in buy_actions))
            self.rebalance_status = RebalanceExecutorStatus.COMPLETED
        except Exception as e:
            self.logger().error(f"Error in rebalance executor: {str(e)}")
            self.rebalance_status = RebalanceExecutorStatus.FAILED

    async def place_order_and_wait(self, action: RebalanceAction):
        """
        Place a market order for ``action`` and wait until it is done.

        :raises Exception: if no order id is returned, the order fails, or placing
            it raises. The status is set to FAILED before re-raising.
        """
        asset = action.asset
        amount = abs(action.amount)
        side = action.side
        try:
            order_id = self.place_order(
                connector_name=self.config.connector_name,
                trading_pair=self.get_trading_pair(asset),
                order_type=OrderType.MARKET,
                side=side,
                amount=amount,
            )
            if not order_id:
                # Without an order id the order cannot be tracked. Fail explicitly
                # so the rebalance is not reported as COMPLETED.
                raise RuntimeError(f"No order id returned for {side} {amount} {asset}")
            self.update_tracked_order(order_id)
            await self.wait_for_order_completion(order_id)
        except Exception as e:
            self.logger().error(f"Error placing order for {asset}: {str(e)}")
            self.rebalance_status = RebalanceExecutorStatus.FAILED
            raise

    async def wait_for_order_completion(self, order_id: str):
        """Poll once per second until ``order_id`` is done; propagates the failure exception."""
        while not self.is_order_complete(order_id):
            await asyncio.sleep(1)  # Check order status every second

    def is_order_complete(self, order_id: str) -> bool:
        """
        Return True once the tracked order is done.

        :raises Exception: if the order finished in a failure state.
        """
        tracked_order = self.update_tracked_order(order_id)
        if tracked_order and tracked_order.is_done:
            if tracked_order.order and tracked_order.order.is_failure:
                raise Exception(f"Order {order_id} failed.")
            return True
        return False

    def get_custom_info(self) -> Dict:
        """Return the executor's status, balances, and configuration for reporting."""
        return {
            "rebalance_status": self.rebalance_status,
            "current_balances": self.current_balances,
            "target_weights": self.target_weights,
            "quote_asset": self.quote_asset,
            "quote_weight": self.quote_weight,
            "min_order_amount_to_rebalance_quote": self.min_order_amount_to_rebalance_quote,
        }

    def update_tracked_order(self, order_id: str) -> TrackedOrder:
        """
        Return the TrackedOrder for ``order_id``, creating it if needed.

        Attaches the connector's in-flight order once it becomes available.
        """
        tracked_order = self.tracked_orders.get(order_id)
        if tracked_order is None:
            tracked_order = TrackedOrder(order_id=order_id)
            self.tracked_orders[order_id] = tracked_order
        if tracked_order.order is None:
            in_flight_order = self.get_in_flight_order(connector_name=self.config.connector_name, order_id=order_id)
            if in_flight_order:
                tracked_order.order = in_flight_order
        return tracked_order
@fengyelingdu
Copy link

Hello, I'm very excited to see this strategy — it's really great. I have a question about cross_momentum_with_trend_overlay_v1.py:
from hummingbot.smart_components.controllers.data_types.data_types import EqualWeighting
from hummingbot.smart_components.executors.rebalance_executor.data_types import (
REBALANCE_EXECUTOR_TYPE,
RebalanceExecutorConfig,
)
These two imports raise errors. Could you provide these two data_types modules? Thank you.

@fengyelingdu
Copy link

> Hello, I'm very excited to see this strategy, it's really great. I want to ask about cross_momentum_with_trend_overlay_v1.py,
> ```python
> from hummingbot.smart_components.controllers.data_types.data_types import EqualWeighting
> from hummingbot.smart_components.executors.rebalance_executor.data_types import (
>     REBALANCE_EXECUTOR_TYPE,
>     RebalanceExecutorConfig,
> )
> ```
> These two imports are throwing errors, can these two data_types classes be provided, thank you

I saw these files in your GitHub repository. Thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment