@haijohn
Created May 8, 2025 15:22
Demo of master/slave (coordinator/worker) pipeline code that coordinates through shared state and result files.
# combined_pipeline.py
import multiprocessing
import time
import os
import pickle
import pandas as pd
import sys
import functools # Import functools for decorators
# --- Configuration ---
# Base file names for results and state (will be appended with worker ID)
WORKER_RESULT_FILE_BASE = "worker_inference_result_"
WORKER_STATE_FILE_BASE = "worker_state_"
# Polling interval for the coordinator to check for worker state/result files (in seconds)
POLLING_INTERVAL = 2
# Timeout for the coordinator waiting for ALL workers (in seconds)
COORDINATOR_TIMEOUT = 30 # Coordinator will wait up to 30 seconds for all workers
# --- Worker States ---
STATE_STARTED = "started"
STATE_INFERENCE_COMPLETE = "inference_complete"
STATE_RESULTS_SAVED = "results_saved" # State indicating results are saved to file
STATE_FAILED = "failed"
# --- Number of Workers ---
# This should be consistent across coordinator and worker launches
NUM_WORKERS = 3 # Set the desired number of worker processes
# --- File Utility Functions ---
def get_worker_result_file(worker_id: int) -> str:
"""Returns the result file path for a given worker ID."""
return f"{WORKER_RESULT_FILE_BASE}{worker_id}.pkl"
def get_worker_state_file(worker_id: int) -> str:
"""Returns the state file path for a given worker ID."""
return f"{WORKER_STATE_FILE_BASE}{worker_id}.pkl"
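# For example, worker 2 writes its state to "worker_state_2.pkl" and its result to
# "worker_inference_result_2.pkl", given the base names configured above.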
def save_state_to_file(worker_id: int, state: str):
"""Saves the current state for a specific worker to its state file."""
state_file = get_worker_state_file(worker_id)
try:
with open(state_file, 'wb') as f:
pickle.dump(state, f)
print(f"Process {os.getpid()}: Worker {worker_id}: State saved to {state_file}: {state}")
except Exception as e:
print(f"Process {os.getpid()}: Worker {worker_id}: Error saving state '{state}' to file {state_file}: {e}")
def load_state_from_file(worker_id: int) -> str | None:
"""Loads the current state for a specific worker from its state file."""
state_file = get_worker_state_file(worker_id)
if not os.path.exists(state_file):
return None
try:
with open(state_file, 'rb') as f:
state = pickle.load(f)
return state
except Exception as e:
print(f"Process {os.getpid()}: Error loading state from file {state_file}: {e}")
return None
def save_result_to_file(worker_id: int, result: pd.DataFrame):
"""Saves the inference result for a specific worker to its result file."""
result_file = get_worker_result_file(worker_id)
try:
with open(result_file, 'wb') as f:
pickle.dump(result, f)
print(f"Process {os.getpid()}: Worker {worker_id}: Results saved to {result_file}")
except Exception as e:
print(f"Process {os.getpid()}: Worker {worker_id}: Error saving results to file {result_file}: {e}")
        # If saving fails, mark this worker as failed so the coordinator aborts
        # instead of waiting forever for a result file that will never appear.
save_state_to_file(worker_id, STATE_FAILED)
raise # Re-raise the exception
def load_result_from_file(worker_id: int) -> pd.DataFrame | None:
"""Loads the inference result for a specific worker from its result file."""
result_file = get_worker_result_file(worker_id)
if not os.path.exists(result_file):
return None
try:
with open(result_file, 'rb') as f:
result = pickle.load(f)
return result
except Exception as e:
print(f"Process {os.getpid()}: Error loading results from file {result_file}: {e}")
return None
def wait_for_all_workers_completion(num_workers: int = NUM_WORKERS, timeout: int = COORDINATOR_TIMEOUT, polling_interval: int = POLLING_INTERVAL) -> bool:
"""
Waits for all worker jobs to complete by monitoring their state files.
Returns True if all workers complete successfully within timeout, False otherwise.
"""
print(f"Process {os.getpid()}: Coordinator: Waiting for {num_workers} workers to reach state '{STATE_RESULTS_SAVED}' with timeout {timeout}s...")
start_time = time.time()
workers_completed = set() # Keep track of which workers have completed
# Worker IDs will be from 2 to num_workers + 1
worker_ids_to_wait_for = set(range(2, num_workers + 2))
while time.time() - start_time < timeout:
all_workers_ready = True
for worker_id in worker_ids_to_wait_for:
if worker_id in workers_completed:
continue # This worker is already done
worker_status = load_state_from_file(worker_id)
if worker_status == STATE_FAILED:
print(f"Process {os.getpid()}: Coordinator: Worker {worker_id} reported failure via state file. Aborting wait.")
return False # Any worker failure means overall failure
if worker_status == STATE_RESULTS_SAVED:
# Check if the result file also exists
if os.path.exists(get_worker_result_file(worker_id)):
print(f"Process {os.getpid()}: Coordinator: Worker {worker_id} reported saved results and file found.")
workers_completed.add(worker_id)
else:
print(f"Process {os.getpid()}: Coordinator: Worker {worker_id} state is '{STATE_RESULTS_SAVED}' but result file not found yet. Waiting...")
all_workers_ready = False # This worker is not fully ready yet
else:
# If state is not 'RESULTS_SAVED' or 'FAILED', it's not ready
# print(f"Process {os.getpid()}: Coordinator: Worker {worker_id} current state: {worker_status}. Waiting...") # Too verbose
all_workers_ready = False
if all_workers_ready and len(workers_completed) == num_workers:
print(f"Process {os.getpid()}: Coordinator: All {num_workers} workers have successfully completed.")
return True # All workers are done and results are available
time.sleep(polling_interval)
# Timeout reached
    print(f"Process {os.getpid()}: Coordinator: Timed out after {timeout}s before all workers completed successfully.")
return False
def clean_up_worker_files(num_workers: int = NUM_WORKERS):
"""Cleans up state and result files for all workers."""
print("Cleaning up previous worker state and result files...")
for worker_id in range(2, num_workers + 2):
state_file = get_worker_state_file(worker_id)
result_file = get_worker_result_file(worker_id)
if os.path.exists(state_file):
os.remove(state_file)
print(f"Cleaned up {state_file}")
if os.path.exists(result_file):
os.remove(result_file)
print(f"Cleaned up {result_file}")
print("Cleanup complete.")
def load_all_worker_results(num_workers: int = NUM_WORKERS) -> list[pd.DataFrame] | None:
"""
Loads inference results from all worker result files.
Returns a list of DataFrames or None if loading fails for any worker.
"""
print(f"Process {os.getpid()}: Coordinator: Loading results from {num_workers} worker files...")
worker_results = []
# Worker IDs are from 2 to num_workers + 1
for worker_id in range(2, num_workers + 2):
try:
worker_inference_result = load_result_from_file(worker_id)
if worker_inference_result is None:
print(f"Process {os.getpid()}: Coordinator: Failed to load result from worker {worker_id} file.")
return None # Return None if any worker result file is missing/corrupt
# The 'partition_id' is already in the DataFrame loaded from the file.
worker_results.append(worker_inference_result)
print(f"Process {os.getpid()}: Coordinator: Loaded result from worker {worker_id}.")
except Exception as e:
print(f"Process {os.getpid()}: Coordinator: Error loading results from worker {worker_id} file: {e}")
return None # Return None if any error occurs during loading
print(f"Process {os.getpid()}: Coordinator: All {num_workers} worker results loaded successfully.")
return worker_results
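# Optional debugging helper (a minimal sketch, not used by the pipeline itself):
# snapshot the last recorded state of every worker by reading its state file, which
# can be handy when the coordinator times out and you want to see which worker stalled.
def snapshot_worker_states(num_workers: int = NUM_WORKERS) -> dict[int, str | None]:
    """Returns a mapping of worker ID to its last recorded state (None if no state file exists)."""
    return {worker_id: load_state_from_file(worker_id) for worker_id in range(2, num_workers + 2)}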
# --- Decorators ---
def update_worker_inference_state_and_save_result(func):
"""
A decorator that updates the worker's state file before and after
the inference function runs, and saves the result to file.
Only applies to workers (partition_id > 1).
"""
@functools.wraps(func) # Use functools.wraps to preserve function metadata
def wrapper(*args, **kwargs):
        # partition_id is expected as the second positional argument (or as a keyword);
        # default to 1 (coordinator) if it is not provided.
        partition_id = args[1] if len(args) > 1 else kwargs.get('partition_id', 1)
worker_id = partition_id # Worker ID is the same as partition ID
# Only update state for worker partitions (ID > 1)
if partition_id > 1:
save_state_to_file(worker_id, STATE_STARTED)
print(f"Process {os.getpid()}: Worker {worker_id}: Updated state to {STATE_STARTED}")
try:
result = func(*args, **kwargs)
# Only update state and save result for worker partitions (ID > 1)
if partition_id > 1:
save_state_to_file(worker_id, STATE_INFERENCE_COMPLETE)
print(f"Process {os.getpid()}: Worker {worker_id}: Updated state to {STATE_INFERENCE_COMPLETE}")
# Save the result to the worker's result file
save_result_to_file(worker_id, result)
# Update state: Results Saved (after saving the file)
save_state_to_file(worker_id, STATE_RESULTS_SAVED)
print(f"Process {os.getpid()}: Worker {worker_id}: Updated state to {STATE_RESULTS_SAVED}")
return result
except Exception as e:
# If an error occurs and it's a worker partition, update state to FAILED
if partition_id > 1:
save_state_to_file(worker_id, STATE_FAILED)
print(f"Process {os.getpid()}: Worker {worker_id}: Updated state to {STATE_FAILED} due to error.")
raise # Re-raise the exception
return wrapper
# --- Core Pipeline Functions ---
@update_worker_inference_state_and_save_result # Apply the state update and save decorator
def inference(data_partition: pd.DataFrame, partition_id: int, simulate_failure: bool = False) -> pd.DataFrame:
"""
Simulates the inference step of the ML pipeline.
Replace with your actual model inference logic.
    The simulate_failure flag exists only to exercise the timeout/failure handling.
"""
    # State updates and result saving for workers are handled by the decorator above.
# Simulate some work
sleep_duration = 3 + partition_id # Simulate different inference times
if simulate_failure and partition_id > 1: # Simulate failure only for workers (partition_id > 1)
time.sleep(sleep_duration / 2) # Fail halfway through
raise RuntimeError(f"Simulated failure in partition {partition_id} inference")
time.sleep(sleep_duration)
# Simulate generating some result data
result = data_partition.copy()
# Ensure column name is unique to the partition for clarity
result[f'inference_output_{partition_id}'] = result['value'] * 10
# Add partition ID to the result DataFrame here, as it's part of the result
# and will be saved by the decorator for workers.
result['partition_id'] = partition_id
return result
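# For worker 2, for example, the returned DataFrame contains the columns
# ['value', 'inference_output_2', 'partition_id'], which the decorator then pickles
# to worker_inference_result_2.pkl.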
def postprocess(combined_results: pd.DataFrame):
"""
Simulates the postprocessing step.
Replace with your actual postprocessing logic.
"""
print(f"Process {os.getpid()}: Starting postprocessing...")
# Simulate some work
time.sleep(5)
# Simulate processing the combined results
# Ensure columns exist before accessing - need to check for all worker outputs
# Assuming 'partition_id' column exists in combined_results
if 'partition_id' in combined_results.columns:
output_cols = [f'inference_output_{i}' for i in combined_results['partition_id'].unique() if f'inference_output_{i}' in combined_results.columns]
if len(output_cols) > 0:
# Example: Summing up all inference outputs per row
combined_results['final_output'] = combined_results[output_cols].sum(axis=1)
else:
print(f"Process {os.getpid()}: Warning: Could not find any inference output columns for postprocessing.")
combined_results['final_output'] = pd.NA # Or some other default
else:
print(f"Process {os.getpid()}: Warning: 'partition_id' column not found in combined_results. Cannot perform postprocessing based on partition outputs.")
combined_results['final_output'] = pd.NA # Or some other default
print(f"Process {os.getpid()}: Postprocessing finished.")
print("\n--- Final Results ---")
print(combined_results)
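# Note on the sum above: after concatenation each row has exactly one populated
# inference_output_<id> column (the rest are NaN), and DataFrame.sum skips NaN by
# default, so 'final_output' is simply that row's single inference value.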
# --- Combined Pipeline Job Function ---
def run_partition_job(data_partition: pd.DataFrame, partition_id: int, simulate_failure: bool = False):
"""
Runs the pipeline for a specific data partition based on its ID.
Partition ID 1 is the coordinator. Partition IDs > 1 are workers.
Uses files for state and result exchange.
"""
role = 'coordinator' if partition_id == 1 else 'worker'
print(f"Process {os.getpid()}: Running as {role} with Partition ID {partition_id}...")
# Run inference for this partition.
# The decorators handle state updates and result saving for workers.
try:
inference_result = inference(data_partition, partition_id, simulate_failure=simulate_failure)
# The 'partition_id' is added to the result DataFrame inside the inference function.
except Exception as e:
        print(f"Process {os.getpid()} (Partition {partition_id}): Error during inference: {e}")
sys.exit(1) # Exit the process if inference fails
if role == 'coordinator':
print(f"Process {os.getpid()}: Coordinator: Starting coordinator tasks...")
# 1. Coordinator's inference result is already computed and stored in inference_result
# 2. Wait for ALL workers' results to be available via state files with a timeout
# The decorator handles STATE_STARTED, STATE_INFERENCE_COMPLETE, STATE_FAILED,
# and STATE_RESULTS_SAVED for workers. We wait for STATE_RESULTS_SAVED here.
all_workers_successful = wait_for_all_workers_completion(NUM_WORKERS)
if not all_workers_successful:
print(f"Process {os.getpid()}: Coordinator: Not all workers completed successfully. Cannot proceed with postprocessing.")
# Handle the timeout/failure scenario - e.g., log error, send alert, process only coordinator data, etc.
sys.exit(1) # Exit if workers did not complete
# 3. Fetch ALL workers' results from their result files using the new function
worker_results = load_all_worker_results(NUM_WORKERS)
if worker_results is None: # load_all_worker_results returns None on failure
print(f"Process {os.getpid()}: Coordinator: Failed to load results from all workers. Cannot proceed with postprocessing.")
sys.exit(1)
# 4. Combine coordinator and ALL worker results
print(f"Process {os.getpid()}: Coordinator: Combining results from coordinator and {NUM_WORKERS} workers...")
# The coordinator's result is in inference_result
all_results = [inference_result] + worker_results
        # Column names are already distinguished per partition (inference_output_<id>)
        # inside inference(), so a simple concat with a fresh index is sufficient here.
combined_results = pd.concat(all_results, ignore_index=True)
print(f"Process {os.getpid()}: Coordinator: Results combined.")
# 5. Run postprocessing on the combined results
postprocess(combined_results)
print(f"Process {os.getpid()}: Coordinator: Coordinator job finished successfully.")
elif role == 'worker':
worker_id = partition_id # Worker ID is the same as its partition ID
print(f"Process {os.getpid()}: Worker {worker_id}: Starting worker tasks...")
        # The decorated inference call above already ran inference, saved the result,
        # and advanced the state to STATE_RESULTS_SAVED (or STATE_FAILED on error),
        # so nothing more is needed here.
print(f"Process {os.getpid()}: Worker {worker_id}: Worker job finished successfully.")
else:
print(f"Process {os.getpid()}: Invalid role/partition ID: {partition_id}. Must be 1 (coordinator) or > 1 (worker).")
sys.exit(1)
# --- Data Partitioning Helper ---
def partition_data(df: pd.DataFrame, num_partitions: int) -> list[pd.DataFrame]:
"""Partitions the input DataFrame into the specified number of parts."""
partition_size = len(df) // num_partitions
partitions = [df.iloc[i * partition_size:(i + 1) * partition_size].copy() for i in range(num_partitions)]
# Ensure all data is included even if not perfectly divisible
if len(df) % num_partitions != 0:
partitions[-1] = pd.concat([partitions[-1], df.iloc[(num_partitions - 1) * partition_size + len(partitions[-1]):].copy()], ignore_index=True)
return partitions
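# For example, 100 rows split into 4 partitions yields sizes [25, 25, 25, 25]; with
# 10 rows and 4 partitions the sizes would be [2, 2, 2, 4], because remainder rows
# are appended to the last partition.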
# --- Main Execution Block (for local testing) ---
# To run on separate nodes, you would launch this script on each node
# with the appropriate command-line arguments ('coordinator' or 'worker <partition_id>')
# and ensure the data partition is available to that script instance.
if __name__ == "__main__":
print("--- Distributed Pipeline Simulation (Local Test using Files) ---")
print("To run on separate nodes, ensure the state and result files are accessible")
print("by both nodes (e.g., on a shared network drive or cloud storage).")
print("Run this script on one node with 'python your_script_name.py coordinator'")
print("and on other nodes with 'python your_script_name.py worker <partition_id>'")
print("(where <partition_id> is 2, 3, ..., NUM_WORKERS + 1).")
print("-" * 60)
# --- Prepare Dummy Data ---
# Create a dummy dataset and split it into NUM_WORKERS + 1 partitions
    data = {'value': range(100)}  # Dummy dataset large enough to split across all partitions
df = pd.DataFrame(data)
total_partitions = NUM_WORKERS + 1
data_partitions = partition_data(df, total_partitions)
print(f"Main: Data partitioned into {total_partitions} parts.")
for i, part in enumerate(data_partitions):
print(f"Main: Partition {i+1} size: {len(part)}")
# Clean up previous worker state and result files before starting
clean_up_worker_files(NUM_WORKERS)
print("\n--- Launching Coordinator and Worker Processes (Local Simulation) ---")
print("This simulates running them on different nodes by using multiprocessing.")
print("In a real distributed setup, you would launch this script separately")
print("on your different machines with the appropriate arguments.")
# --- Command Line Arguments to Determine Role and Partition ---
# This allows running the same script with different arguments
# to act as coordinator or a specific worker.
# Usage:
# python combined_pipeline.py coordinator
# python combined_pipeline.py worker 2 # For worker with partition ID 2
# python combined_pipeline.py worker 3 # For worker with partition ID 3, etc.
# python combined_pipeline.py local_test # For local simulation
if len(sys.argv) > 1:
role = sys.argv[1].lower()
if role == 'coordinator':
if len(data_partitions) < 1:
print("Error: Not enough data partitions for coordinator.")
sys.exit(1)
coordinator_partition_data = data_partitions[0] # Partition 1 data
# Run the combined pipeline job function as the coordinator
# In a real multi-node setup, you would run this directly
# without multiprocessing.Process if this script is the entry point.
# For local simulation, we use Process to mimic separation.
# run_partition_job(coordinator_partition_data, 1, simulate_failure=False) # For direct execution
coordinator_process = multiprocessing.Process(target=run_partition_job, args=(coordinator_partition_data, 1, False))
coordinator_process.start()
coordinator_process.join()
elif role == 'worker':
if len(sys.argv) > 2:
try:
partition_id = int(sys.argv[2])
if partition_id < 2 or partition_id > NUM_WORKERS + 1:
print(f"Invalid worker partition ID: {partition_id}. Must be between 2 and {NUM_WORKERS + 1}.")
sys.exit(1)
if partition_id > len(data_partitions):
print(f"Error: Data partition {partition_id} does not exist.")
sys.exit(1)
worker_partition_data = data_partitions[partition_id - 1] # Get the correct partition data (0-indexed)
# Set simulate_failure=True here for a specific worker ID to test failure
# simulate_failure = (partition_id == 2) # Example: Simulate failure for worker 2
simulate_failure = False # Default: no failure simulation
# Run the combined pipeline job function as a worker
# In a real multi-node setup, you would run this directly
# without multiprocessing.Process if this script is the entry point.
# For local simulation, we use Process to mimic separation.
# run_partition_job(worker_partition_data, partition_id, simulate_failure=simulate_failure) # For direct execution
worker_process = multiprocessing.Process(target=run_partition_job, args=(worker_partition_data, partition_id, simulate_failure))
worker_process.start()
worker_process.join()
except ValueError:
print("Invalid worker partition ID. Must be an integer.")
sys.exit(1)
else:
print("Please specify the worker partition ID (e.g., 'python combined_pipeline.py worker 2').")
sys.exit(1)
elif role == 'local_test':
# This option runs the coordinator and all workers locally using multiprocessing
print("\n--- Running Local Test Simulation ---")
# Launch the coordinator process
coordinator_process = multiprocessing.Process(target=run_partition_job, args=(data_partitions[0], 1, False)) # Coordinator is partition 1
coordinator_process.start()
# Launch worker processes
worker_processes = []
# Worker IDs are from 2 to NUM_WORKERS + 1
for i in range(NUM_WORKERS):
worker_id = i + 2 # Worker IDs start from 2
worker_partition_data = data_partitions[i + 1] # Workers get partitions from index 1 onwards
# Set simulate_failure=True for a specific worker ID to test failure
# simulate_failure = (worker_id == 2) # Example: Simulate failure for worker 2
simulate_failure = False # Default: no failure simulation
worker_process = multiprocessing.Process(target=run_partition_job, args=(worker_partition_data, worker_id, simulate_failure))
worker_processes.append(worker_process)
worker_process.start()
# Wait for Processes to Finish
coordinator_process.join()
for p in worker_processes:
p.join()
print("\nLocal test simulation finished.")
else:
print("Invalid role specified. Use 'coordinator', 'worker <partition_id>', or 'local_test'.")
sys.exit(1)
else:
print("Please specify the role to run: 'coordinator', 'worker <partition_id>', or 'local_test'.")
print("Example: python combined_pipeline.py coordinator")
print("Example: python combined_pipeline.py worker 2")
print("Example: python combined_pipeline.py local_test")
sys.exit(1)
print("Main: Script finished.")