# Source: GitHub Gist esshka/eb398e3a6be4138dac7136ce03b56217
# (created by @esshka on December 30, 2024).
import json
import logging
import asyncio
import websockets
from typing import Dict, List, Optional, Set
from datetime import datetime
import hmac
import base64
import time
import redis
import os
from src.market_data.redis_sliding_buffer import RedisSlidingBuffer
# Module-level logger named after this module. Handlers and levels are left
# to the application's logging configuration; nothing is configured here.
logger = logging.getLogger(name=__name__)
class CandlestickWebSocketClient:
    """Stream OKX candlestick data over an authenticated WebSocket into a sliding buffer.

    Lifecycle:
        * ``connect()`` opens the socket, logs in, and starts two background
          tasks. ``_ping_loop`` sends a text 'ping' every 15 s.
          ``_monitor_connection`` reconnects whenever the client is marked
          disconnected.
        * ``subscribe_candlesticks()`` sends a subscribe request and waits for
          a confirmation on ``subscription_responses``.
        * ``process_messages()`` is the socket reader after login. It routes
          subscription events to ``subscription_responses`` and stores candle
          rows through ``sliding_buffer.push_data``.

    NOTE(review): ``process_messages()`` returns when the connection drops,
    and nothing in this class restarts it after the monitor reconnects. The
    caller must run it again; otherwise no data and no subscription
    confirmations are observed after a reconnect.
    """

    # OKX candle channel suffixes. subscribe_candlesticks() falls back to "1m"
    # for any timeframe not listed here.
    _BAR_MAP = {
        "1m": "1m",
        "3m": "3m",
        "5m": "5m",
        "15m": "15m",
        "30m": "30m",
        "1H": "1H",
        "2H": "2H",
        "4H": "4H",
    }

    def __init__(self, api_key: str, api_secret: str, api_passphrase: str,
                 sliding_buffer: Optional["RedisSlidingBuffer"] = None):
        """Store credentials and initialise connection state.

        Args:
            api_key: OKX API key sent in the login message.
            api_secret: Secret used to HMAC-sign the login message.
            api_passphrase: Passphrase sent in the login message.
            sliding_buffer: Destination for parsed candles. When omitted, a
                ``RedisSlidingBuffer`` is built on a Redis client configured
                from the ``REDIS_HOST`` / ``REDIS_PORT`` environment variables.
        """
        self.api_key = api_key
        self.api_secret = api_secret
        self.api_passphrase = api_passphrase
        self.ws_url = "wss://ws.okx.com:8443/ws/v5/business"
        # Confirmed subscriptions as "INSTID:timeframe" keys; replayed on reconnect.
        self.subscriptions: Set[str] = set()
        self.websocket = None
        self.is_connected = False
        self.ping_task = None
        self.reconnect_task = None
        # Subscription events (and candle data while a subscribe is pending),
        # forwarded by process_messages() to subscribe_candlesticks().
        self.subscription_responses = asyncio.Queue()
        # Number of subscribe calls currently waiting for confirmation. Data
        # messages are queued only while this is > 0; otherwise nothing would
        # drain them and the queue would grow without bound.
        self._pending_subscribes = 0
        # True while reconnect() replays subscriptions, so a reconnect nested
        # inside a failing resubscribe does not start another replay.
        self._resubscribing = False
        if sliding_buffer is None:
            # Only build a Redis client when no buffer was injected.
            redis_client = redis.Redis(
                host=os.getenv('REDIS_HOST', 'localhost'),
                port=int(os.getenv('REDIS_PORT', 6379)),
                db=0,
                decode_responses=True,
            )
            sliding_buffer = RedisSlidingBuffer(redis_client)
        self.sliding_buffer = sliding_buffer
        logger.debug("CandlestickWebSocketClient initialized")

    def _get_timestamp(self) -> str:
        """Return the current Unix time in whole seconds, as a string.

        This is the ``timestamp`` field of the login message. It is epoch
        seconds, not an ISO-8601 string.
        """
        return str(int(time.time()))

    def _get_signature(self, timestamp: str, method: str = 'GET',
                       request_path: str = '/users/self/verify') -> str:
        """Sign ``timestamp + method + request_path`` for the OKX login.

        Returns:
            The base64-encoded HMAC-SHA256 digest, keyed with ``api_secret``.
        """
        message = timestamp + method + request_path
        mac = hmac.new(
            self.api_secret.encode('utf-8'),
            message.encode('utf-8'),
            digestmod='sha256',
        )
        return base64.b64encode(mac.digest()).decode()

    @staticmethod
    def _redact_auth_message(auth_message: Dict) -> Dict:
        """Return a copy of a login message with apiKey/passphrase/sign masked.

        The input is not modified. Use the result for logging only.
        """
        redacted_args = []
        for arg in auth_message.get("args", []):
            masked = dict(arg)
            for secret_field in ("apiKey", "passphrase", "sign"):
                if secret_field in masked:
                    masked[secret_field] = "***"
            redacted_args.append(masked)
        return {**auth_message, "args": redacted_args}

    async def _authenticate(self) -> Dict:
        """Build the OKX ``login`` message for the current time.

        Returns:
            The login payload: ``op`` plus one ``args`` entry holding apiKey,
            passphrase, timestamp and signature.

        Raises:
            Any error raised while building the message, after logging it.
        """
        try:
            timestamp = self._get_timestamp()
            logger.debug(f"Generated timestamp: {timestamp}")
            signature = self._get_signature(timestamp)
            logger.debug("Generated signature")
            auth_message = {
                "op": "login",
                "args": [{
                    "apiKey": self.api_key,
                    "passphrase": self.api_passphrase,
                    "timestamp": timestamp,
                    "sign": signature,
                }],
            }
            # Never write the raw passphrase/signature to the logs.
            logger.debug(f"Generated complete auth message: {json.dumps(self._redact_auth_message(auth_message), indent=2)}")
            return auth_message
        except Exception as e:
            logger.error(f"Error generating authentication message: {str(e)}", exc_info=True)
            raise

    def _drain_subscription_responses(self) -> None:
        """Discard every message currently queued in ``subscription_responses``."""
        while True:
            try:
                self.subscription_responses.get_nowait()
            except asyncio.QueueEmpty:
                return

    async def connect(self):
        """Open the WebSocket, log in, and start the ping and monitor tasks.

        ``is_connected`` is set only after a successful login. A monitor task
        is started only if none is running, so a reconnect issued by the
        monitor itself does not spawn a second monitor.

        Returns:
            True when the login succeeded. On any failure the connection is
            closed via ``close()`` and False is returned.
        """
        try:
            logger.info(f"Connecting to WebSocket URL: {self.ws_url}")
            self.websocket = await websockets.connect(
                self.ws_url,
                ping_interval=20,
                ping_timeout=10,
                close_timeout=10,
                max_size=2**23,  # 8MB max message size
                compression=None,  # Disable compression for better performance
                extra_headers={
                    'User-Agent': 'OKX-Trading-Bot/1.0'
                },
            )
            logger.info("WebSocket connection object created")

            auth_message = await self._authenticate()
            logger.info("Generated authentication message")
            logger.debug(f"Auth message content: {json.dumps(self._redact_auth_message(auth_message), indent=2)}")

            self._drain_subscription_responses()

            logger.info("Sending authentication message...")
            await self.websocket.send(json.dumps(auth_message))
            logger.info("Authentication message sent, waiting for response...")

            try:
                response = await asyncio.wait_for(self.websocket.recv(), timeout=10.0)
            except asyncio.TimeoutError:
                logger.error("❌ Authentication response timeout")
                await self.close()
                return False

            logger.info("Received auth response")
            logger.debug(f"Raw auth response: {response}")
            auth_response = json.loads(response)
            logger.info(f"Authentication response: {json.dumps(auth_response, indent=2)}")

            if auth_response.get('event') != 'login' or auth_response.get('code') != '0':
                logger.error(f"❌ Authentication failed: {json.dumps(auth_response, indent=2)}")
                await self.close()
                return False

            logger.info("✅ Authentication successful")
            # Mark connected only now, so the ping loop and process_messages()
            # cannot start reading/writing while the login recv() is pending.
            self.is_connected = True

            logger.info("Starting ping task...")
            if self.ping_task is not None and not self.ping_task.done():
                self.ping_task.cancel()
            self.ping_task = asyncio.create_task(self._ping_loop())
            logger.info("Ping task started")

            # When connect() runs inside the monitor (via reconnect()), the
            # monitor is still alive; keep exactly one.
            if self.reconnect_task is None or self.reconnect_task.done():
                logger.info("Starting reconnection monitor...")
                self.reconnect_task = asyncio.create_task(self._monitor_connection())
                logger.info("Reconnection monitor started")

            logger.info("✅ WebSocket connection fully established and authenticated")
            return True
        except Exception as e:
            logger.error(f"Error connecting to WebSocket: {str(e)}", exc_info=True)
            self.is_connected = False
            await self.close()
            return False

    async def _ping_loop(self, interval: float = 15.0):
        """Send the text frame 'ping' every ``interval`` seconds while connected.

        This loop does not reconnect by itself. A dead connection is detected
        by the library keepalive configured in connect() (ping_interval /
        ping_timeout). Once the socket is closed, send() raises
        ConnectionClosed, the client is marked disconnected, and the monitor
        task performs the reconnect.
        """
        ping_count = 0
        while self.is_connected:
            websocket = self.websocket
            if websocket is None:
                logger.error("WebSocket connection is None in ping loop")
                break
            try:
                ping_count += 1
                logger.debug(f"Sending ping message #{ping_count}")
                await websocket.send('ping')
                logger.debug(f"Ping #{ping_count} sent successfully")
                await asyncio.sleep(interval)
            except websockets.exceptions.ConnectionClosed:
                logger.error("WebSocket connection closed in ping loop")
                self.is_connected = False
                break
            except Exception as e:
                logger.error(f"Error in ping loop: {str(e)}")
                break

    async def _monitor_connection(self, check_interval: float = 5.0):
        """Every ``check_interval`` seconds, reconnect if the client is disconnected.

        Runs until cancelled by ``close()``. reconnect() never cancels the
        task that is calling it, so this loop survives its own reconnects.
        """
        while True:
            try:
                if not self.is_connected or not self.websocket:
                    logger.info("Connection lost, attempting to reconnect...")
                    await self.reconnect()
                await asyncio.sleep(check_interval)
            except websockets.exceptions.ConnectionClosed:
                logger.error("WebSocket connection closed in monitor")
                self.is_connected = False
                await self.reconnect()
            except Exception as e:
                logger.error(f"Error in connection monitor: {str(e)}")
                await asyncio.sleep(check_interval)

    @staticmethod
    def _parse_subscription_key(subscription_key: str) -> "tuple[str, str]":
        """Split an ``"INSTID:timeframe"`` key into ``(inst_id, timeframe)``.

        A key without a colon is treated as a bare instrument id with the
        default "1m" timeframe.
        """
        inst_id, separator, timeframe = subscription_key.rpartition(":")
        if not separator:
            return subscription_key, "1m"
        return inst_id, timeframe

    async def reconnect(self):
        """Close the connection, reconnect, and replay the stored subscriptions.

        Errors are logged, not raised. If the new connection cannot be
        established, no resubscription is attempted; the monitor retries
        later. A reconnect nested inside an ongoing replay only re-establishes
        the connection, and the outer call keeps replaying.
        """
        try:
            await self.close()
            if not await self.connect():
                logger.error("Reconnect failed: could not re-establish the WebSocket connection")
                return
            if self._resubscribing:
                return
            self._resubscribing = True
            try:
                # Iterate a snapshot: subscribe_candlesticks() mutates the set.
                for subscription in list(self.subscriptions):
                    inst_id, timeframe = self._parse_subscription_key(subscription)
                    await self.subscribe_candlesticks(inst_id, timeframe)
            finally:
                self._resubscribing = False
        except Exception as e:
            logger.error(f"Error reconnecting: {str(e)}")

    async def close(self):
        """Cancel the background tasks, close the socket, and mark the client disconnected.

        The task that is running close() is never cancelled (e.g. the monitor
        while it reconnects); only the other background tasks are.
        ``is_connected`` is cleared even if closing the socket fails.
        """
        current = asyncio.current_task()
        try:
            if self.ping_task is not None and self.ping_task is not current:
                self.ping_task.cancel()
                self.ping_task = None
            if self.reconnect_task is not None and self.reconnect_task is not current:
                self.reconnect_task.cancel()
                self.reconnect_task = None
            if self.websocket:
                await self.websocket.close()
        except Exception as e:
            logger.error(f"Error closing connection: {str(e)}")
        finally:
            self.is_connected = False

    @staticmethod
    def _response_matches(response_data: Dict, inst_id: str, bar: str) -> bool:
        """Return True if ``response_data['arg']`` names ``candle<bar>`` on ``inst_id``.

        The instrument id is compared case-insensitively. A message without
        an ``arg`` matches nothing.
        """
        arg = response_data.get('arg') or {}
        return (arg.get('channel', '') == f"candle{bar}"
                and str(arg.get('instId', '')).upper() == inst_id.upper())

    async def _await_subscription_confirmation(self, subscription_msg: Dict, inst_id: str,
                                               timeframe: str, bar: str,
                                               timeout: float = 30.0,
                                               retries: int = 3) -> Optional[bool]:
        """Wait for the server to confirm ``subscription_msg``, re-sending it between attempts.

        Each attempt reads ``subscription_responses``, waiting up to 10 s per
        read. An attempt ends on a matching reply, a read timeout, or an error
        reply. ``timeout`` bounds the reading across all attempts combined.

        Returns:
            True if confirmed (the subscription key is recorded), False if the
            server reported an invalid channel, or None if nothing confirmed it.

        Raises:
            The exception from the last attempt, if every attempt raised.
        """
        subscription_key = f"{inst_id.upper()}:{timeframe}"
        start_time = time.time()
        logger.info(f"Starting subscription confirmation loop with {timeout}s timeout and {retries} retries")
        for attempt in range(retries):
            try:
                while time.time() - start_time < timeout:
                    logger.info(f"Attempt {attempt + 1}/{retries}: Waiting for response from subscription_responses queue")
                    try:
                        response_data = await asyncio.wait_for(self.subscription_responses.get(), timeout=10.0)
                    except asyncio.TimeoutError:
                        logger.warning(f"Timeout waiting for subscription confirmation on attempt {attempt + 1}/{retries}")
                        if not self.is_connected or not self.websocket:
                            logger.error("WebSocket connection lost during subscription wait")
                        break  # Re-send (or give up) below
                    logger.info(f"Got subscription response: {json.dumps(response_data, indent=2)}")

                    if 'event' in response_data:
                        event = response_data['event']
                        if event == 'subscribe':
                            if response_data.get('arg') and not self._response_matches(response_data, inst_id, bar):
                                # Late reply to a different subscribe request; keep waiting.
                                logger.warning(f"Ignoring subscribe event for another channel: {response_data.get('arg')}")
                                continue
                            if not response_data.get('msg'):  # No error message means success
                                self.subscriptions.add(subscription_key)
                                logger.info(f"Successfully subscribed to candle{bar} for {inst_id}")
                                logger.info(f"Updated subscriptions: {self.subscriptions}")
                                return True
                            error_msg = response_data.get('msg', 'Unknown error')
                            logger.error(f"Subscription failed - Message: {error_msg}")
                            if 'invalid channel' in error_msg.lower():
                                logger.error("Invalid channel specified, check timeframe format")
                                return False
                            break  # Re-send below
                        if event == 'error':
                            logger.error(f"Received error event: {json.dumps(response_data, indent=2)}")
                            if 'auth' in str(response_data.get('msg', '')).lower():
                                logger.error("Authentication error, attempting to reconnect...")
                                await self.reconnect()
                            break  # Re-send below
                    elif 'data' in response_data:
                        # A data push for exactly this channel/instrument proves
                        # the subscription is live. Other candle streams do not.
                        arg = response_data.get('arg', {})
                        channel = arg.get('channel', '')
                        received_inst_id = arg.get('instId', '').upper()
                        logger.info(f"Received data message with channel: {channel}, instId: {received_inst_id}")
                        if self._response_matches(response_data, inst_id, bar):
                            self.subscriptions.add(subscription_key)
                            logger.info(f"Subscription confirmed via data message for {inst_id}")
                            logger.info(f"Updated subscriptions: {self.subscriptions}")
                            return True
                        logger.warning(f"Received data message but channel or instId doesn't match: {channel} != candle{bar} or {received_inst_id} != {inst_id.upper()}")

                if attempt < retries - 1:
                    logger.info(f"Retrying subscription (attempt {attempt + 2}/{retries})...")
                    if not self.is_connected or not self.websocket:
                        logger.error("WebSocket connection lost before retry")
                        break  # Let the caller reconnect
                    await self.websocket.send(json.dumps(subscription_msg))
            except Exception as e:
                logger.error(f"Error in subscription attempt {attempt + 1}: {str(e)}")
                if attempt < retries - 1:
                    continue
                raise
        return None

    async def subscribe_candlesticks(self, inst_id: str, timeframe: str = "1m"):
        """Subscribe to ``candle<timeframe>`` for ``inst_id`` and wait for confirmation.

        The subscription is confirmed by a ``subscribe`` event for the same
        channel/instrument, or by a candle data message for exactly that
        channel/instrument. Confirmation requires ``process_messages()`` to
        be running, since it feeds ``subscription_responses``. The request is
        re-sent up to 3 times per connection, and the connection is
        re-established up to 3 times with exponential backoff.

        Args:
            inst_id: OKX instrument id; upper-cased before sending.
            timeframe: A key of ``_BAR_MAP``; any other value falls back to "1m".

        Returns:
            True once confirmed; the ``"INSTID:timeframe"`` key is then added
            to ``self.subscriptions``. False otherwise.
        """
        max_reconnect_attempts = 3
        confirmation_retries = 3
        reconnect_attempt = 0
        while reconnect_attempt < max_reconnect_attempts:
            try:
                logger.info(f"Starting subscription process for {inst_id} with timeframe {timeframe} (attempt {reconnect_attempt + 1}/{max_reconnect_attempts})")
                logger.info(f"Current subscriptions: {self.subscriptions}")
                if not self.is_connected:
                    logger.info("Not connected, attempting to connect first...")
                    if not await self.connect():
                        logger.error("Failed to establish WebSocket connection")
                        reconnect_attempt += 1
                        await asyncio.sleep(2 ** reconnect_attempt)  # Exponential backoff
                        continue
                # Verify the connection is still alive
                if not self.websocket or not self.is_connected:
                    logger.error("WebSocket connection lost during subscription")
                    reconnect_attempt += 1
                    await asyncio.sleep(2 ** reconnect_attempt)  # Exponential backoff
                    continue

                bar = self._BAR_MAP.get(timeframe, "1m")
                logger.info(f"Mapped timeframe {timeframe} to {bar}")
                subscription_msg = {
                    "op": "subscribe",
                    "args": [{
                        "channel": f"candle{bar}",
                        "instId": inst_id.upper(),  # OKX requires uppercase
                    }],
                }
                logger.info("Preparing to send subscription message")
                logger.debug(f"Subscription message content: {json.dumps(subscription_msg, indent=2)}")

                # Drop stale responses so only replies to this request are considered.
                self._drain_subscription_responses()
                # While > 0, process_messages() also forwards candle data to the queue.
                self._pending_subscribes += 1
                try:
                    logger.info("Sending subscription message...")
                    await self.websocket.send(json.dumps(subscription_msg))
                    logger.info("Subscription message sent successfully")
                    confirmed = await self._await_subscription_confirmation(
                        subscription_msg, inst_id, timeframe, bar, retries=confirmation_retries
                    )
                finally:
                    self._pending_subscribes -= 1

                if confirmed is not None:
                    return confirmed

                logger.error(f"Subscription confirmation timeout for {inst_id} after {confirmation_retries} attempts")
                reconnect_attempt += 1
                if reconnect_attempt < max_reconnect_attempts:
                    logger.info(f"Attempting reconnection {reconnect_attempt + 1}/{max_reconnect_attempts}")
                    await self.reconnect()
                    await asyncio.sleep(2 ** reconnect_attempt)  # Exponential backoff
                    continue
                return False
            except Exception as e:
                logger.error(f"Error in subscription process: {str(e)}", exc_info=True)
                reconnect_attempt += 1
                if reconnect_attempt < max_reconnect_attempts:
                    logger.info(f"Attempting reconnection {reconnect_attempt + 1}/{max_reconnect_attempts}")
                    await self.reconnect()
                    await asyncio.sleep(2 ** reconnect_attempt)  # Exponential backoff
                    continue
                return False
        logger.error(f"Failed to subscribe to {inst_id} after {max_reconnect_attempts} reconnection attempts")
        return False

    async def _process_candlestick_data(self, message: dict):
        """Parse the candle rows in ``message['data']`` and push each one to the sliding buffer.

        Rows must be lists of at least 7 values
        ``[timestamp, open, high, low, close, vol, volCcy, ...]``; extra
        trailing values are ignored. Shorter or non-list rows are skipped.
        Rows whose numeric fields cannot be parsed are logged and skipped, and
        the rest of the batch is still stored. The timeframe is the channel
        name with ``candle`` removed. Errors are logged, never raised.
        """
        try:
            if 'data' in message and isinstance(message['data'], list):
                trading_pair = message['arg']['instId']
                timeframe = message['arg']['channel'].replace('candle', '')
                logger.info(f"Processing candlestick data for {trading_pair}")
                for candle in message['data']:
                    if not (isinstance(candle, list) and len(candle) >= 7):
                        continue
                    try:
                        candlestick_data = {
                            'timestamp': candle[0],
                            'open': float(candle[1]),
                            'high': float(candle[2]),
                            'low': float(candle[3]),
                            'close': float(candle[4]),
                            'volume': float(candle[5]),
                            'volCcy': float(candle[6]),
                        }
                    except (TypeError, ValueError) as e:
                        logger.warning(f"Skipping malformed candle for {trading_pair}: {candle!r} ({e})")
                        continue
                    success = await self.sliding_buffer.push_data(
                        trading_pair=trading_pair,
                        timeframe=timeframe,
                        data=candlestick_data,
                    )
                    if not success:
                        logger.error(f"Failed to push candlestick data for {trading_pair}")
        except Exception as e:
            logger.error(f"Error processing candlestick data: {str(e)}", exc_info=True)

    async def process_messages(self):
        """Read messages while connected and route each one by type.

        - 'pong' replies to the ping loop are ignored.
        - 'subscribe' and 'error' events are put on ``subscription_responses``.
        - Candle data is stored via ``_process_candlestick_data``. While a
          subscribe call is waiting, it is also put on
          ``subscription_responses`` as a confirmation signal.

        Returns once ``is_connected`` is False, including after the socket
        closes.

        Raises:
            Unexpected errors from outside the read loop, after marking the
            client disconnected.
        """
        try:
            logger.info("🔄 Starting WebSocket message processing...")
            message_count = 0
            logger.info(f"Connection state: {self.is_connected}")
            logger.info(f"WebSocket object: {'Connected' if self.websocket else 'Not connected'}")
            logger.info(f"Active subscriptions: {self.subscriptions}")
            while self.is_connected:
                try:
                    if not self.websocket:
                        logger.error("❌ WebSocket connection is None in message processing")
                        break
                    message = await self.websocket.recv()
                    message_count += 1
                    logger.debug(f"📥 Message #{message_count} received: {message}")
                    if message == 'pong':
                        logger.debug(f"Received pong message #{message_count}")
                        continue
                    try:
                        data = json.loads(message)
                        message_type = 'event' if 'event' in data else 'data' if 'data' in data else 'unknown'
                        logger.debug(f"Parsed message type: {message_type}")
                        # Route subscription responses to the subscription queue
                        if message_type == 'event':
                            if data['event'] == 'subscribe':
                                logger.info(f"📬 Subscription response received: {json.dumps(data, indent=2)}")
                                await self.subscription_responses.put(data)
                                continue
                            elif data['event'] == 'error':
                                logger.error(f"❌ Error event received: {json.dumps(data, indent=2)}")
                                await self.subscription_responses.put(data)
                                continue
                        if message_type == 'data' and 'arg' in data and data['arg'].get('channel', '').startswith('candle'):
                            inst_id = data['arg'].get('instId', 'unknown')
                            logger.debug(f"📊 Candlestick data received for {inst_id}")
                            # Forward data only while a subscribe call is waiting for
                            # confirmation; otherwise nothing drains the queue.
                            if self._pending_subscribes > 0:
                                await self.subscription_responses.put(data)
                            await self._process_candlestick_data(data)
                        else:
                            logger.debug(f"Other message type received: {json.dumps(data, indent=2)}")
                    except json.JSONDecodeError:
                        logger.error(f"❌ Failed to parse message #{message_count} as JSON: {message}")
                    except Exception as e:
                        logger.error(f"❌ Error processing message #{message_count}: {str(e)}", exc_info=True)
                except websockets.exceptions.ConnectionClosed:
                    logger.error("❌ WebSocket connection closed in message processing")
                    self.is_connected = False
                    break
                except Exception as e:
                    logger.error(f"❌ Error in message processing loop: {str(e)}", exc_info=True)
                    if not self.is_connected:
                        break
                    await asyncio.sleep(1)
        except Exception as e:
            logger.error(f"❌ Fatal error in message processing: {str(e)}", exc_info=True)
            self.is_connected = False
            raise
# Shared client instance; stays None until initialize_ws_client() creates it.
candlestick_ws_client = None


def initialize_ws_client(api_key: str, api_secret: str, api_passphrase: str) -> "CandlestickWebSocketClient":
    """Return the module-wide CandlestickWebSocketClient, creating it on first call.

    Args:
        api_key: OKX API key, used only when the client does not exist yet.
        api_secret: OKX API secret, used only when the client does not exist yet.
        api_passphrase: OKX API passphrase, used only when the client does not exist yet.

    Returns:
        The shared client. Later calls return the existing instance and
        ignore their credential arguments.
    """
    global candlestick_ws_client
    if candlestick_ws_client is None:
        candlestick_ws_client = CandlestickWebSocketClient(api_key, api_secret, api_passphrase)
    return candlestick_ws_client
# End of gist; discussion happens in the gist's comment thread on GitHub.