Created
December 30, 2024 08:49
-
-
Save esshka/eb398e3a6be4138dac7136ce03b56217 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import base64
import hmac
import json
import logging
import os
import time
from datetime import datetime
from typing import Dict, List, Optional, Set

import redis
import websockets

from src.market_data.redis_sliding_buffer import RedisSlidingBuffer
# Module-level logger; handlers and levels are left to the application.
logger = logging.getLogger(__name__)
class CandlestickWebSocketClient:
    """Authenticated OKX WebSocket client for candlestick ("candle") channels.

    The client logs in on the OKX business endpoint, subscribes to candle
    channels and forwards every received candle to a sliding buffer through
    its ``push_data`` coroutine.  After a successful login two background
    tasks are started: a keep-alive ping loop and a connection monitor that
    reconnects whenever the connection is flagged as lost.

    ``subscribe_candlesticks`` waits for confirmations on
    ``subscription_responses``; that queue is only filled by
    ``process_messages``, so the latter must be running in another task for
    a subscription to be confirmed.
    """

    # Upper bound on buffered subscription responses.  Every candle message is
    # offered to the queue, so without a bound it grows forever whenever no
    # subscription is waiting on it.
    _RESPONSE_QUEUE_MAXSIZE = 1000

    # Candle intervals this client requests verbatim; anything else falls back
    # to "1m" (with a warning).
    _SUPPORTED_BARS = frozenset({"1m", "3m", "5m", "15m", "30m", "1H", "2H", "4H"})

    # Confirmation wait: total time budget in seconds and number of sends.
    _CONFIRMATION_TIMEOUT = 30
    _CONFIRMATION_RETRIES = 3

    def __init__(self, api_key: str, api_secret: str, api_passphrase: str,
                 sliding_buffer: Optional["RedisSlidingBuffer"] = None):
        """Store credentials and prepare (but do not open) the connection.

        Args:
            api_key: OKX API key.
            api_secret: OKX API secret used to sign the login request.
            api_passphrase: OKX API passphrase.
            sliding_buffer: Destination for candles.  When omitted, a
                ``RedisSlidingBuffer`` is created on a Redis client configured
                from ``REDIS_HOST`` / ``REDIS_PORT`` (defaults
                ``localhost`` / ``6379``).
        """
        self.api_key = api_key
        self.api_secret = api_secret
        self.api_passphrase = api_passphrase
        self.ws_url = "wss://ws.okx.com:8443/ws/v5/business"
        self.subscriptions: Set[str] = set()  # Active "INSTID:timeframe" keys
        self.websocket = None
        self.is_connected = False
        self.ping_task = None
        self.reconnect_task = None
        # Responses relevant to a pending subscription (events and candle data)
        self.subscription_responses = asyncio.Queue(maxsize=self._RESPONSE_QUEUE_MAXSIZE)
        # True while reconnect() is running; makes nested calls a no-op.
        self._reconnecting = False

        # Only build a Redis client when the caller did not inject a buffer.
        if sliding_buffer is None:
            redis_client = redis.Redis(
                host=os.getenv('REDIS_HOST', 'localhost'),
                port=int(os.getenv('REDIS_PORT', 6379)),
                db=0,
                decode_responses=True
            )
            sliding_buffer = RedisSlidingBuffer(redis_client)
        self.sliding_buffer = sliding_buffer
        logger.debug("CandlestickWebSocketClient initialized")

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _redacted(auth_message: Dict) -> Dict:
        """Return a copy of *auth_message* with credentials masked for logging."""
        masked_args = [
            {**arg, "apiKey": "***", "passphrase": "***", "sign": "***"}
            for arg in auth_message.get("args", [])
        ]
        return {**auth_message, "args": masked_args}

    @staticmethod
    def _split_subscription_key(subscription_key: str):
        """Split an ``"INSTID:timeframe"`` key into ``(inst_id, timeframe)``.

        A key without a timeframe part yields the default timeframe ``"1m"``.
        """
        inst_id, _, timeframe = subscription_key.partition(":")
        return inst_id, timeframe or "1m"

    @staticmethod
    def _response_targets(response: dict, bar: str, inst_id: str) -> bool:
        """Return True if the response's ``arg`` names exactly this candle channel and instrument."""
        arg = response.get('arg') or {}
        return (
            arg.get('channel', '') == f"candle{bar}"
            and str(arg.get('instId', '')).upper() == inst_id.upper()
        )

    def _resolve_bar(self, timeframe: str) -> str:
        """Return the OKX bar for *timeframe*, falling back to ``"1m"`` if unsupported."""
        if timeframe in self._SUPPORTED_BARS:
            return timeframe
        logger.warning(f"Unsupported timeframe {timeframe!r}, falling back to 1m")
        return "1m"

    def _drain_subscription_responses(self) -> None:
        """Discard everything currently buffered in ``subscription_responses``."""
        while True:
            try:
                self.subscription_responses.get_nowait()
            except asyncio.QueueEmpty:
                return

    def _offer_subscription_response(self, payload: dict) -> None:
        """Enqueue *payload* without blocking, dropping the oldest entry when full."""
        responses = self.subscription_responses
        if responses.full():
            try:
                responses.get_nowait()
            except asyncio.QueueEmpty:
                pass
        # No await between the check and the put, so the slot is still free.
        responses.put_nowait(payload)

    def _cancel_background_task(self, attribute: str) -> None:
        """Cancel the task stored in *attribute* unless it is the running task.

        close() is reached from inside the monitor task (via reconnect());
        cancelling the current task there would abort the reconnect at its
        next await.
        """
        task = getattr(self, attribute)
        if task is None or task is asyncio.current_task():
            return
        task.cancel()
        setattr(self, attribute, None)

    def _start_background_tasks(self) -> None:
        """Start the ping loop and connection monitor if they are not already running."""
        if self.ping_task is None or self.ping_task.done():
            logger.info("Starting ping task...")
            self.ping_task = asyncio.create_task(self._ping_loop())
            logger.info("Ping task started")
        if self.reconnect_task is None or self.reconnect_task.done():
            logger.info("Starting reconnection monitor...")
            self.reconnect_task = asyncio.create_task(self._monitor_connection())
            logger.info("Reconnection monitor started")

    # ------------------------------------------------------------------
    # Authentication
    # ------------------------------------------------------------------

    def _get_timestamp(self) -> str:
        """Return the current Unix time in whole seconds, as a string.

        NOTE(review): the OKX WebSocket login is documented to take epoch
        seconds (not an ISO timestamp) -- confirm against the OKX API docs.
        """
        return str(int(time.time()))

    def _get_signature(self, timestamp: str, method: str = 'GET', request_path: str = '/users/self/verify') -> str:
        """Return the Base64 HMAC-SHA256 of ``timestamp + method + request_path``.

        The API secret is the HMAC key.
        """
        message = timestamp + method + request_path
        mac = hmac.new(
            bytes(self.api_secret, encoding='utf8'),
            bytes(message, encoding='utf-8'),
            digestmod='sha256'
        )
        return base64.b64encode(mac.digest()).decode()

    async def _authenticate(self) -> Dict:
        """Build the ``login`` request for the current credentials.

        Returns:
            The message dict to send (not yet JSON-encoded).

        Raises:
            Exception: Any error raised while building the message is logged
                and re-raised.
        """
        try:
            timestamp = self._get_timestamp()
            logger.debug(f"Generated timestamp: {timestamp}")
            signature = self._get_signature(timestamp)
            logger.debug("Generated signature")
            auth_message = {
                "op": "login",
                "args": [{
                    "apiKey": self.api_key,
                    "passphrase": self.api_passphrase,
                    "timestamp": timestamp,
                    "sign": signature
                }]
            }
            # Never write credentials to the log, not even at debug level.
            logger.debug(
                f"Generated complete auth message: {json.dumps(self._redacted(auth_message), indent=2)}"
            )
            return auth_message
        except Exception as e:
            logger.error(f"Error generating authentication message: {str(e)}", exc_info=True)
            raise

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(self):
        """Open the WebSocket, log in and start the background tasks.

        Returns:
            True when the server acknowledged the login (``event == 'login'``
            and ``code == '0'``); False on rejection, timeout or any error.
            On failure the connection is closed again.
        """
        try:
            logger.info(f"Connecting to WebSocket URL: {self.ws_url}")
            self.websocket = await websockets.connect(
                self.ws_url,
                ping_interval=20,
                ping_timeout=10,
                close_timeout=10,
                max_size=2**23,  # 8MB max message size
                compression=None,  # Disable compression
                extra_headers={
                    'User-Agent': 'OKX-Trading-Bot/1.0'
                }
            )
            self.is_connected = True
            logger.info("WebSocket connection object created")

            auth_message = await self._authenticate()
            logger.info("Generated authentication message")

            # Start from a clean slate: drop responses left from an earlier session.
            self._drain_subscription_responses()

            logger.info("Sending authentication message...")
            await self.websocket.send(json.dumps(auth_message))
            logger.info("Authentication message sent, waiting for response...")

            try:
                response = await asyncio.wait_for(self.websocket.recv(), timeout=10.0)
            except asyncio.TimeoutError:
                logger.error("❌ Authentication response timeout")
                await self.close()
                return False

            logger.info("Received auth response")
            logger.debug(f"Raw auth response: {response}")
            auth_response = json.loads(response)
            logger.info(f"Authentication response: {json.dumps(auth_response, indent=2)}")

            if auth_response.get('event') == 'login' and auth_response.get('code') == '0':
                logger.info("✅ Authentication successful")
                self._start_background_tasks()
                logger.info("✅ WebSocket connection fully established and authenticated")
                return True

            logger.error(f"❌ Authentication failed: {json.dumps(auth_response, indent=2)}")
            await self.close()
            return False
        except Exception as e:
            logger.error(f"Error connecting to WebSocket: {str(e)}", exc_info=True)
            self.is_connected = False
            await self.close()
            return False

    async def _ping_loop(self):
        """Send a text ``ping`` every 15 seconds while the connection is up.

        The loop ends when the connection is flagged as down, the socket is
        gone, or sending fails.  It never reconnects itself: it clears
        ``is_connected`` and leaves the reconnect to the connection monitor.
        """
        ping_count = 0
        # Time of the last ping that was sent successfully.  Pong replies are
        # consumed by process_messages and are not tracked here.
        last_ping_time = time.time()
        while self.is_connected:
            try:
                if not self.websocket:
                    logger.error("WebSocket connection is None in ping loop")
                    break
                current_time = time.time()
                if current_time - last_ping_time > 30:
                    logger.error("No successful ping in over 30 seconds, flagging connection for reconnect")
                    self.is_connected = False
                    break
                ping_count += 1
                logger.debug(f"Sending ping message #{ping_count}")
                await self.websocket.send('ping')
                logger.debug(f"Ping #{ping_count} sent successfully")
                last_ping_time = current_time
                await asyncio.sleep(15)  # Send ping every 15 seconds
            except websockets.exceptions.ConnectionClosed:
                logger.error("WebSocket connection closed in ping loop")
                self.is_connected = False
                break
            except Exception as e:
                logger.error(f"Error in ping loop: {str(e)}")
                break

    async def _monitor_connection(self):
        """Check the connection every 5 seconds and reconnect when it is down.

        Runs until cancelled (close() cancels it when called from another task).
        """
        while True:
            try:
                if not self.is_connected or not self.websocket:
                    logger.info("Connection lost, attempting to reconnect...")
                    await self.reconnect()
                await asyncio.sleep(5)
            except websockets.exceptions.ConnectionClosed:
                logger.error("WebSocket connection closed in monitor")
                self.is_connected = False
                await self.reconnect()
            except Exception as e:
                logger.error(f"Error in connection monitor: {str(e)}")
                await asyncio.sleep(5)  # Wait before retrying

    async def reconnect(self):
        """Close, reconnect and resubscribe to every recorded subscription.

        Nested calls (for example from ``subscribe_candlesticks`` while this
        method is resubscribing) return immediately.  Errors are logged, not
        raised.  If the new connection cannot be established, no resubscribe
        is attempted.
        """
        if self._reconnecting:
            logger.debug("Reconnect already in progress, skipping nested request")
            return
        self._reconnecting = True
        try:
            await self.close()
            if not await self.connect():
                logger.error("Reconnect failed: connection could not be re-established")
                return
            # Snapshot: subscribe_candlesticks adds to self.subscriptions.
            for subscription_key in tuple(self.subscriptions):
                inst_id, timeframe = self._split_subscription_key(subscription_key)
                await self.subscribe_candlesticks(inst_id, timeframe)
        except Exception as e:
            logger.error(f"Error reconnecting: {str(e)}")
        finally:
            self._reconnecting = False

    async def close(self):
        """Cancel the background tasks and close the WebSocket.

        The task calling close() is never cancelled.  Errors are logged, not
        raised; the client is marked disconnected even if closing the socket
        fails.
        """
        try:
            self._cancel_background_task("ping_task")
            self._cancel_background_task("reconnect_task")
            websocket, self.websocket = self.websocket, None
            self.is_connected = False
            if websocket:
                await websocket.close()
        except Exception as e:
            logger.error(f"Error closing connection: {str(e)}")

    # ------------------------------------------------------------------
    # Subscriptions
    # ------------------------------------------------------------------

    async def subscribe_candlesticks(self, inst_id: str, timeframe: str = "1m"):
        """Subscribe to the candle channel of *inst_id* and wait for confirmation.

        Args:
            inst_id: Instrument id; sent upper-cased.
            timeframe: Candle interval.  Unsupported values fall back to
                ``"1m"``.

        Returns:
            True once the subscription is confirmed (by a ``subscribe`` event
            or by the first matching candle message); False when the server
            reports an invalid channel or all attempts are exhausted.

        Up to three rounds are made; a failed round triggers a reconnect and
        an exponential back-off before the next one.
        """
        max_reconnect_attempts = 3
        reconnect_attempt = 0
        while reconnect_attempt < max_reconnect_attempts:
            try:
                logger.info(f"Starting subscription process for {inst_id} with timeframe {timeframe} (attempt {reconnect_attempt + 1}/{max_reconnect_attempts})")
                logger.info(f"Current subscriptions: {self.subscriptions}")

                if not self.is_connected:
                    logger.info("Not connected, attempting to connect first...")
                    if not await self.connect():
                        logger.error("Failed to establish WebSocket connection")
                        reconnect_attempt += 1
                        await asyncio.sleep(2 ** reconnect_attempt)  # Exponential backoff
                        continue

                # Verify connection is still alive
                if not self.websocket or not self.is_connected:
                    logger.error("WebSocket connection lost during subscription")
                    reconnect_attempt += 1
                    await asyncio.sleep(2 ** reconnect_attempt)  # Exponential backoff
                    continue

                bar = self._resolve_bar(timeframe)
                logger.info(f"Mapped timeframe {timeframe} to {bar}")

                subscription_msg = {
                    "op": "subscribe",
                    "args": [{
                        "channel": f"candle{bar}",
                        "instId": inst_id.upper()  # OKX requires uppercase
                    }]
                }
                logger.info("Preparing to send subscription message")
                logger.debug(f"Subscription message content: {json.dumps(subscription_msg, indent=2)}")

                # Anything buffered so far predates this request.
                self._drain_subscription_responses()

                logger.info("Sending subscription message...")
                await self.websocket.send(json.dumps(subscription_msg))
                logger.info("Subscription message sent successfully")

                outcome = await self._await_subscription_confirmation(
                    inst_id, timeframe, bar, subscription_msg
                )
                if outcome is not None:
                    return outcome
                logger.error(f"Subscription confirmation timeout for {inst_id} after {self._CONFIRMATION_RETRIES} attempts")
            except Exception as e:
                logger.error(f"Error in subscription process: {str(e)}", exc_info=True)

            # This round failed: reconnect and back off, unless it was the last one.
            reconnect_attempt += 1
            if reconnect_attempt >= max_reconnect_attempts:
                return False
            logger.info(f"Attempting reconnection {reconnect_attempt + 1}/{max_reconnect_attempts}")
            await self.reconnect()
            await asyncio.sleep(2 ** reconnect_attempt)  # Exponential backoff

        logger.error(f"Failed to subscribe to {inst_id} after {max_reconnect_attempts} reconnection attempts")
        return False

    async def _await_subscription_confirmation(self, inst_id: str, timeframe: str, bar: str, subscription_msg: dict):
        """Wait for the server to confirm a subscription that was just sent.

        Returns:
            True when confirmed (the ``"INSTID:timeframe"`` key is recorded in
            ``self.subscriptions``), False when the server reports an invalid
            channel, None when no confirmation arrived.

        Raises:
            Exception: An unexpected error on the last retry is re-raised.
        """
        retries = self._CONFIRMATION_RETRIES
        timeout = self._CONFIRMATION_TIMEOUT
        subscription_key = f"{inst_id.upper()}:{timeframe}"
        # One time budget shared by all retries.
        start_time = time.time()
        logger.info(f"Starting subscription confirmation loop with {timeout}s timeout and {retries} retries")

        for attempt in range(retries):
            try:
                while time.time() - start_time < timeout:
                    logger.info(f"Attempt {attempt + 1}/{retries}: Waiting for response from subscription_responses queue")
                    try:
                        response_data = await asyncio.wait_for(self.subscription_responses.get(), timeout=10.0)
                    except asyncio.TimeoutError:
                        logger.warning(f"Timeout waiting for subscription confirmation on attempt {attempt + 1}/{retries}")
                        if not self.is_connected or not self.websocket:
                            logger.error("WebSocket connection lost during subscription wait")
                        break  # Stop waiting; resend below if retries remain
                    logger.info(f"Got subscription response: {json.dumps(response_data, indent=2)}")

                    if 'event' in response_data:
                        event = response_data['event']
                        if event == 'subscribe':
                            if 'arg' in response_data and not self._response_targets(response_data, bar, inst_id):
                                # Confirmation of some other subscription; keep waiting.
                                logger.debug("Ignoring subscribe event for a different channel or instrument")
                                continue
                            if not response_data.get('msg'):  # No error message means success
                                self.subscriptions.add(subscription_key)
                                logger.info(f"Successfully subscribed to candle{bar} for {inst_id}")
                                logger.info(f"Updated subscriptions: {self.subscriptions}")
                                return True
                            error_msg = response_data.get('msg', 'Unknown error')
                            logger.error(f"Subscription failed - Message: {error_msg}")
                            if 'invalid channel' in error_msg.lower():
                                logger.error("Invalid channel specified, check timeframe format")
                                return False
                            break  # Try again
                        if event == 'error':
                            logger.error(f"Received error event: {json.dumps(response_data, indent=2)}")
                            if 'auth' in str(response_data.get('msg', '')).lower():
                                logger.error("Authentication error, attempting to reconnect...")
                                await self.reconnect()
                            break  # Try again
                    elif 'data' in response_data:
                        # The first candle for this exact channel also proves the subscription.
                        channel = response_data.get('arg', {}).get('channel', '')
                        received_inst_id = response_data.get('arg', {}).get('instId', '').upper()
                        logger.info(f"Received data message with channel: {channel}, instId: {received_inst_id}")
                        if self._response_targets(response_data, bar, inst_id):
                            self.subscriptions.add(subscription_key)
                            logger.info(f"Subscription confirmed via data message for {inst_id}")
                            logger.info(f"Updated subscriptions: {self.subscriptions}")
                            return True
                        logger.warning(f"Received data message but channel or instId doesn't match: {channel} != candle{bar} or {received_inst_id} != {inst_id.upper()}")

                if attempt < retries - 1:
                    logger.info(f"Retrying subscription (attempt {attempt + 2}/{retries})...")
                    if not self.is_connected or not self.websocket:
                        logger.error("WebSocket connection lost before retry")
                        break  # Let the caller reconnect
                    await self.websocket.send(json.dumps(subscription_msg))
            except Exception as e:
                logger.error(f"Error in subscription attempt {attempt + 1}: {str(e)}")
                if attempt < retries - 1:
                    continue
                raise
        return None

    # ------------------------------------------------------------------
    # Message handling
    # ------------------------------------------------------------------

    async def _process_candlestick_data(self, message: dict):
        """Convert the candles of one channel message and push them to the buffer.

        Each candle is a list whose first seven entries are read as
        ``[timestamp, open, high, low, close, vol, volCcy]``; extra entries
        are ignored and shorter or non-list candles are skipped.  Errors are
        logged, not raised.
        """
        try:
            if 'data' in message and isinstance(message['data'], list):
                trading_pair = message['arg']['instId']
                timeframe = message['arg']['channel'].replace('candle', '')
                logger.info(f"Processing candlestick data for {trading_pair}")
                for candle in message['data']:
                    if not (isinstance(candle, list) and len(candle) >= 7):
                        logger.warning(f"Skipping malformed candle for {trading_pair}: {candle}")
                        continue
                    candlestick_data = {
                        'timestamp': candle[0],
                        'open': float(candle[1]),
                        'high': float(candle[2]),
                        'low': float(candle[3]),
                        'close': float(candle[4]),
                        'volume': float(candle[5]),
                        'volCcy': float(candle[6])
                    }
                    # Push data to sliding buffer
                    success = await self.sliding_buffer.push_data(
                        trading_pair=trading_pair,
                        timeframe=timeframe,
                        data=candlestick_data
                    )
                    if not success:
                        logger.error(f"Failed to push candlestick data for {trading_pair}")
        except Exception as e:
            logger.error(f"Error processing candlestick data: {str(e)}", exc_info=True)

    async def process_messages(self):
        """Receive and dispatch messages until the connection is flagged as down.

        ``pong`` replies are dropped; ``subscribe`` / ``error`` events and
        candle data are offered to ``subscription_responses``; candle data is
        additionally pushed to the sliding buffer.

        Raises:
            Exception: An error escaping the receive loop is logged, marks
                the client disconnected and is re-raised.
        """
        try:
            logger.info("🚀 Starting WebSocket message processing...")
            message_count = 0
            logger.info(f"Connection state: {self.is_connected}")
            logger.info(f"WebSocket object: {'Connected' if self.websocket else 'Not connected'}")
            logger.info(f"Active subscriptions: {self.subscriptions}")

            while self.is_connected:
                try:
                    if not self.websocket:
                        logger.error("❌ WebSocket connection is None in message processing")
                        break
                    message = await self.websocket.recv()
                    message_count += 1
                    logger.debug(f"📥 Message #{message_count} received: {message}")

                    if message == 'pong':
                        logger.debug(f"Received pong message #{message_count}")
                        continue

                    try:
                        data = json.loads(message)
                        if not isinstance(data, dict):
                            logger.debug(f"Other message type received: {json.dumps(data, indent=2)}")
                            continue
                        message_type = 'event' if 'event' in data else 'data' if 'data' in data else 'unknown'
                        logger.debug(f"Parsed message type: {message_type}")

                        # Route subscription responses to subscription queue
                        if message_type == 'event':
                            if data['event'] == 'subscribe':
                                logger.info(f"📬 Subscription response received: {json.dumps(data, indent=2)}")
                                self._offer_subscription_response(data)
                                continue
                            elif data['event'] == 'error':
                                logger.error(f"❌ Error event received: {json.dumps(data, indent=2)}")
                                self._offer_subscription_response(data)
                                continue

                        # Handle candlestick data
                        if message_type == 'data' and 'arg' in data and data['arg'].get('channel', '').startswith('candle'):
                            inst_id = data['arg'].get('instId', 'unknown')
                            logger.debug(f"📊 Candlestick data received for {inst_id}")
                            # Offer the message first: the first candle may
                            # serve as the subscription confirmation.
                            self._offer_subscription_response(data)
                            # Process and publish the data
                            await self._process_candlestick_data(data)
                        else:
                            logger.debug(f"Other message type received: {json.dumps(data, indent=2)}")
                    except json.JSONDecodeError:
                        logger.error(f"❌ Failed to parse message #{message_count} as JSON: {message}")
                    except Exception as e:
                        logger.error(f"❌ Error processing message #{message_count}: {str(e)}", exc_info=True)
                except websockets.exceptions.ConnectionClosed:
                    logger.error("❌ WebSocket connection closed in message processing")
                    self.is_connected = False
                    break
                except Exception as e:
                    logger.error(f"❌ Error in message processing loop: {str(e)}", exc_info=True)
                    if not self.is_connected:
                        break
                    await asyncio.sleep(1)
        except Exception as e:
            logger.error(f"❌ Fatal error in message processing: {str(e)}", exc_info=True)
            self.is_connected = False
            raise
# Process-wide client instance; created lazily by initialize_ws_client().
candlestick_ws_client = None


def initialize_ws_client(api_key: str, api_secret: str, api_passphrase: str):
    """Return the shared CandlestickWebSocketClient, creating it on first use.

    The credentials are only used by the call that creates the instance;
    later calls return the existing client and ignore their arguments.
    """
    global candlestick_ws_client
    if candlestick_ws_client is None:
        candlestick_ws_client = CandlestickWebSocketClient(api_key, api_secret, api_passphrase)
    return candlestick_ws_client
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment