Demo of the master/slave (coordinator/worker) pipeline: workers and the coordinator exchange state and results through files on shared storage.
# combined_pipeline.py
import multiprocessing
import time
import os
import pickle
import pandas as pd
import sys
import functools  # Import functools for decorators

# --- Configuration ---
# Base file names for results and state (will be appended with worker ID)
WORKER_RESULT_FILE_BASE = "worker_inference_result_"
WORKER_STATE_FILE_BASE = "worker_state_"
# Polling interval for the coordinator to check for worker state/result files (in seconds)
POLLING_INTERVAL = 2
# Timeout for the coordinator waiting for ALL workers (in seconds)
COORDINATOR_TIMEOUT = 30  # Coordinator will wait up to 30 seconds for all workers

# --- Worker States ---
STATE_STARTED = "started"
STATE_INFERENCE_COMPLETE = "inference_complete"
STATE_RESULTS_SAVED = "results_saved"  # State indicating results are saved to file
STATE_FAILED = "failed"

# --- Number of Workers ---
# This should be consistent across coordinator and worker launches
NUM_WORKERS = 3  # Set the desired number of worker processes


# --- File Utility Functions ---
def get_worker_result_file(worker_id: int) -> str:
    """Returns the result file path for a given worker ID."""
    return f"{WORKER_RESULT_FILE_BASE}{worker_id}.pkl"


def get_worker_state_file(worker_id: int) -> str:
    """Returns the state file path for a given worker ID."""
    return f"{WORKER_STATE_FILE_BASE}{worker_id}.pkl"


def save_state_to_file(worker_id: int, state: str):
    """Saves the current state for a specific worker to its state file."""
    state_file = get_worker_state_file(worker_id)
    try:
        with open(state_file, 'wb') as f:
            pickle.dump(state, f)
        print(f"Process {os.getpid()}: Worker {worker_id}: State saved to {state_file}: {state}")
    except Exception as e:
        print(f"Process {os.getpid()}: Worker {worker_id}: Error saving state '{state}' to file {state_file}: {e}")


def load_state_from_file(worker_id: int) -> str | None:
    """Loads the current state for a specific worker from its state file."""
    state_file = get_worker_state_file(worker_id)
    if not os.path.exists(state_file):
        return None
    try:
        with open(state_file, 'rb') as f:
            state = pickle.load(f)
        return state
    except Exception as e:
        print(f"Process {os.getpid()}: Error loading state from file {state_file}: {e}")
        return None


def save_result_to_file(worker_id: int, result: pd.DataFrame):
    """Saves the inference result for a specific worker to its result file."""
    result_file = get_worker_result_file(worker_id)
    try:
        with open(result_file, 'wb') as f:
            pickle.dump(result, f)
        print(f"Process {os.getpid()}: Worker {worker_id}: Results saved to {result_file}")
    except Exception as e:
        print(f"Process {os.getpid()}: Worker {worker_id}: Error saving results to file {result_file}: {e}")
        # If saving fails, we should indicate failure.
        # Note: In this design, the calling decorator handles setting STATE_FAILED
        # if save_result_to_file is called within its try block and fails.
        # However, if save_result_to_file is called directly and fails,
        # this catch block ensures the state is updated.
        save_state_to_file(worker_id, STATE_FAILED)
        raise  # Re-raise the exception


def load_result_from_file(worker_id: int) -> pd.DataFrame | None:
    """Loads the inference result for a specific worker from its result file."""
    result_file = get_worker_result_file(worker_id)
    if not os.path.exists(result_file):
        return None
    try:
        with open(result_file, 'rb') as f:
            result = pickle.load(f)
        return result
    except Exception as e:
        print(f"Process {os.getpid()}: Error loading results from file {result_file}: {e}")
        return None
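
# --- Illustrative sketch (not part of the pipeline flow): the per-worker handshake ---
# A worker with partition ID 2 moves through the state files roughly like this; every
# name below is defined above except result_df, which is a placeholder for the worker's
# inference output. The coordinator only has to poll load_state_from_file(2) until it
# sees STATE_RESULTS_SAVED (or STATE_FAILED), then call load_result_from_file(2).
#
#   save_state_to_file(2, STATE_STARTED)             # worker announces it is running
#   # ...run inference, producing result_df...
#   save_state_to_file(2, STATE_INFERENCE_COMPLETE)  # inference done, result not yet on disk
#   save_result_to_file(2, result_df)                # writes worker_inference_result_2.pkl
#   save_state_to_file(2, STATE_RESULTS_SAVED)       # coordinator now treats worker 2 as done
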
def wait_for_all_workers_completion(num_workers: int = NUM_WORKERS, timeout: int = COORDINATOR_TIMEOUT, polling_interval: int = POLLING_INTERVAL) -> bool:
    """
    Waits for all worker jobs to complete by monitoring their state files.
    Returns True if all workers complete successfully within the timeout, False otherwise.
    """
    print(f"Process {os.getpid()}: Coordinator: Waiting for {num_workers} workers to reach state '{STATE_RESULTS_SAVED}' with timeout {timeout}s...")
    start_time = time.time()
    workers_completed = set()  # Keep track of which workers have completed
    # Worker IDs will be from 2 to num_workers + 1
    worker_ids_to_wait_for = set(range(2, num_workers + 2))
    while time.time() - start_time < timeout:
        all_workers_ready = True
        for worker_id in worker_ids_to_wait_for:
            if worker_id in workers_completed:
                continue  # This worker is already done
            worker_status = load_state_from_file(worker_id)
            if worker_status == STATE_FAILED:
                print(f"Process {os.getpid()}: Coordinator: Worker {worker_id} reported failure via state file. Aborting wait.")
                return False  # Any worker failure means overall failure
            if worker_status == STATE_RESULTS_SAVED:
                # Check if the result file also exists
                if os.path.exists(get_worker_result_file(worker_id)):
                    print(f"Process {os.getpid()}: Coordinator: Worker {worker_id} reported saved results and file found.")
                    workers_completed.add(worker_id)
                else:
                    print(f"Process {os.getpid()}: Coordinator: Worker {worker_id} state is '{STATE_RESULTS_SAVED}' but result file not found yet. Waiting...")
                    all_workers_ready = False  # This worker is not fully ready yet
            else:
                # If state is not 'RESULTS_SAVED' or 'FAILED', it's not ready
                # print(f"Process {os.getpid()}: Coordinator: Worker {worker_id} current state: {worker_status}. Waiting...")  # Too verbose
                all_workers_ready = False
        if all_workers_ready and len(workers_completed) == num_workers:
            print(f"Process {os.getpid()}: Coordinator: All {num_workers} workers have successfully completed.")
            return True  # All workers are done and results are available
        time.sleep(polling_interval)
    # Timeout reached
    print(f"Process {os.getpid()}: Coordinator: Timeout reached ({timeout}s) or not all workers completed successfully.")
    return False


def clean_up_worker_files(num_workers: int = NUM_WORKERS):
    """Cleans up state and result files for all workers."""
    print("Cleaning up previous worker state and result files...")
    for worker_id in range(2, num_workers + 2):
        state_file = get_worker_state_file(worker_id)
        result_file = get_worker_result_file(worker_id)
        if os.path.exists(state_file):
            os.remove(state_file)
            print(f"Cleaned up {state_file}")
        if os.path.exists(result_file):
            os.remove(result_file)
            print(f"Cleaned up {result_file}")
    print("Cleanup complete.")


def load_all_worker_results(num_workers: int = NUM_WORKERS) -> list[pd.DataFrame] | None:
    """
    Loads inference results from all worker result files.
    Returns a list of DataFrames, or None if loading fails for any worker.
    """
    print(f"Process {os.getpid()}: Coordinator: Loading results from {num_workers} worker files...")
    worker_results = []
    # Worker IDs are from 2 to num_workers + 1
    for worker_id in range(2, num_workers + 2):
        try:
            worker_inference_result = load_result_from_file(worker_id)
            if worker_inference_result is None:
                print(f"Process {os.getpid()}: Coordinator: Failed to load result from worker {worker_id} file.")
                return None  # Return None if any worker result file is missing/corrupt
            # The 'partition_id' is already in the DataFrame loaded from the file.
            worker_results.append(worker_inference_result)
            print(f"Process {os.getpid()}: Coordinator: Loaded result from worker {worker_id}.")
        except Exception as e:
            print(f"Process {os.getpid()}: Coordinator: Error loading results from worker {worker_id} file: {e}")
            return None  # Return None if any error occurs during loading
    print(f"Process {os.getpid()}: Coordinator: All {num_workers} worker results loaded successfully.")
    return worker_results
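
# --- Illustrative sketch: how the coordinator uses the helpers above ---
# `coordinator_result` is a placeholder for the coordinator's own inference output; the
# full version of this sequence lives in run_partition_job() further down.
#
#   if wait_for_all_workers_completion(NUM_WORKERS):
#       worker_results = load_all_worker_results(NUM_WORKERS)
#       if worker_results is not None:
#           combined = pd.concat([coordinator_result] + worker_results, ignore_index=True)
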
# --- Decorators ---
# Removed the log_inference_timing decorator as requested.
def update_worker_inference_state_and_save_result(func):
    """
    A decorator that updates the worker's state file before and after
    the inference function runs, and saves the result to file.
    Only applies to workers (partition_id > 1).
    """
    @functools.wraps(func)  # Use functools.wraps to preserve function metadata
    def wrapper(*args, **kwargs):
        # Assuming partition_id is the second argument
        partition_id = args[1] if len(args) > 1 else 1  # Default to 1 if not provided
        worker_id = partition_id  # Worker ID is the same as partition ID
        # Only update state for worker partitions (ID > 1)
        if partition_id > 1:
            save_state_to_file(worker_id, STATE_STARTED)
            print(f"Process {os.getpid()}: Worker {worker_id}: Updated state to {STATE_STARTED}")
        try:
            result = func(*args, **kwargs)
            # Only update state and save result for worker partitions (ID > 1)
            if partition_id > 1:
                save_state_to_file(worker_id, STATE_INFERENCE_COMPLETE)
                print(f"Process {os.getpid()}: Worker {worker_id}: Updated state to {STATE_INFERENCE_COMPLETE}")
                # Save the result to the worker's result file
                save_result_to_file(worker_id, result)
                # Update state: Results Saved (after saving the file)
                save_state_to_file(worker_id, STATE_RESULTS_SAVED)
                print(f"Process {os.getpid()}: Worker {worker_id}: Updated state to {STATE_RESULTS_SAVED}")
            return result
        except Exception:
            # If an error occurs and it's a worker partition, update state to FAILED
            if partition_id > 1:
                save_state_to_file(worker_id, STATE_FAILED)
                print(f"Process {os.getpid()}: Worker {worker_id}: Updated state to {STATE_FAILED} due to error.")
            raise  # Re-raise the exception
    return wrapper
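
# --- Illustrative sketch: reusing the decorator for your own inference function ---
# Any function whose second positional argument is the partition ID and which returns a
# DataFrame gets the same state/result bookkeeping for free. `my_model_inference` is a
# hypothetical stand-in, not part of this file.
#
#   @update_worker_inference_state_and_save_result
#   def my_model_inference(data_partition: pd.DataFrame, partition_id: int) -> pd.DataFrame:
#       out = data_partition.copy()
#       out["prediction"] = 0.0          # your real model call goes here
#       out["partition_id"] = partition_id
#       return out
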
# --- Core Pipeline Functions ---
# Removed @log_inference_timing decorator
@update_worker_inference_state_and_save_result  # Apply the state update and save decorator
def inference(data_partition: pd.DataFrame, partition_id: int, simulate_failure: bool = False) -> pd.DataFrame:
    """
    Simulates the inference step of the ML pipeline.
    Replace with your actual model inference logic.
    Added simulate_failure for testing timeout/failure handling.
    """
    # The logging, state updates, and result saving for workers
    # are handled by the decorator.
    # Simulate some work
    sleep_duration = 3 + partition_id  # Simulate different inference times
    if simulate_failure and partition_id > 1:  # Simulate failure only for workers (partition_id > 1)
        time.sleep(sleep_duration / 2)  # Fail halfway through
        raise RuntimeError(f"Simulated failure in partition {partition_id} inference")
    time.sleep(sleep_duration)
    # Simulate generating some result data
    result = data_partition.copy()
    # Ensure the column name is unique to the partition for clarity
    result[f'inference_output_{partition_id}'] = result['value'] * 10
    # Add the partition ID to the result DataFrame here, as it's part of the result
    # and will be saved by the decorator for workers.
    result['partition_id'] = partition_id
    return result


def postprocess(combined_results: pd.DataFrame):
    """
    Simulates the postprocessing step.
    Replace with your actual postprocessing logic.
    """
    print(f"Process {os.getpid()}: Starting postprocessing...")
    # Simulate some work
    time.sleep(5)
    # Simulate processing the combined results.
    # Ensure columns exist before accessing - need to check for all worker outputs.
    # Assumes a 'partition_id' column exists in combined_results.
    if 'partition_id' in combined_results.columns:
        output_cols = [f'inference_output_{i}' for i in combined_results['partition_id'].unique() if f'inference_output_{i}' in combined_results.columns]
        if len(output_cols) > 0:
            # Example: Summing up all inference outputs per row
            combined_results['final_output'] = combined_results[output_cols].sum(axis=1)
        else:
            print(f"Process {os.getpid()}: Warning: Could not find any inference output columns for postprocessing.")
            combined_results['final_output'] = pd.NA  # Or some other default
    else:
        print(f"Process {os.getpid()}: Warning: 'partition_id' column not found in combined_results. Cannot perform postprocessing based on partition outputs.")
        combined_results['final_output'] = pd.NA  # Or some other default
    print(f"Process {os.getpid()}: Postprocessing finished.")
    print("\n--- Final Results ---")
    print(combined_results)


# --- Combined Pipeline Job Function ---
def run_partition_job(data_partition: pd.DataFrame, partition_id: int, simulate_failure: bool = False):
    """
    Runs the pipeline for a specific data partition based on its ID.
    Partition ID 1 is the coordinator. Partition IDs > 1 are workers.
    Uses files for state and result exchange.
    """
    role = 'coordinator' if partition_id == 1 else 'worker'
    print(f"Process {os.getpid()}: Running as {role} with Partition ID {partition_id}...")
    # Run inference for this partition.
    # The decorator handles state updates and result saving for workers.
    try:
        inference_result = inference(data_partition, partition_id, simulate_failure=simulate_failure)
        # The 'partition_id' is added to the result DataFrame inside the inference function.
    except Exception as e:
        print(f"Process {os.getpid()}: Partition {partition_id}: Error during inference: {e}")
        sys.exit(1)  # Exit the process if inference fails
    if role == 'coordinator':
        print(f"Process {os.getpid()}: Coordinator: Starting coordinator tasks...")
        # 1. The coordinator's inference result is already computed and stored in inference_result.
        # 2. Wait for ALL workers' results to be available via state files, with a timeout.
        #    The decorator handles STATE_STARTED, STATE_INFERENCE_COMPLETE, STATE_FAILED,
        #    and STATE_RESULTS_SAVED for workers. We wait for STATE_RESULTS_SAVED here.
        all_workers_successful = wait_for_all_workers_completion(NUM_WORKERS)
        if not all_workers_successful:
            print(f"Process {os.getpid()}: Coordinator: Not all workers completed successfully. Cannot proceed with postprocessing.")
            # Handle the timeout/failure scenario - e.g., log an error, send an alert, process only coordinator data, etc.
            sys.exit(1)  # Exit if workers did not complete
        # 3. Fetch ALL workers' results from their result files
        worker_results = load_all_worker_results(NUM_WORKERS)
        if worker_results is None:  # load_all_worker_results returns None on failure
            print(f"Process {os.getpid()}: Coordinator: Failed to load results from all workers. Cannot proceed with postprocessing.")
            sys.exit(1)
        # 4. Combine the coordinator's result and ALL worker results
        print(f"Process {os.getpid()}: Coordinator: Combining results from coordinator and {NUM_WORKERS} workers...")
        # The coordinator's result is in inference_result
        all_results = [inference_result] + worker_results
        # Column names are already distinguishable per partition (handled in inference),
        # so a simple concat with a fresh global index is enough.
        combined_results = pd.concat(all_results, ignore_index=True)
        print(f"Process {os.getpid()}: Coordinator: Results combined.")
        # 5. Run postprocessing on the combined results
        postprocess(combined_results)
        print(f"Process {os.getpid()}: Coordinator: Coordinator job finished successfully.")
    elif role == 'worker':
        worker_id = partition_id  # Worker ID is the same as its partition ID
        print(f"Process {os.getpid()}: Worker {worker_id}: Starting worker tasks...")
        # The inference call above, with its decorator, handled the worker's tasks:
        # running inference, updating state, saving the result, and setting STATE_RESULTS_SAVED.
        # No further steps are needed in the worker block after successful inference.
        # If inference failed, the try/except block above would have caught it and exited.
        print(f"Process {os.getpid()}: Worker {worker_id}: Worker job finished successfully.")
    else:
        print(f"Process {os.getpid()}: Invalid role/partition ID: {partition_id}. Must be 1 (coordinator) or > 1 (worker).")
        sys.exit(1)


# --- Data Partitioning Helper ---
def partition_data(df: pd.DataFrame, num_partitions: int) -> list[pd.DataFrame]:
    """Partitions the input DataFrame into the specified number of parts."""
    partition_size = len(df) // num_partitions
    partitions = [df.iloc[i * partition_size:(i + 1) * partition_size].copy() for i in range(num_partitions)]
    # Ensure all data is included even if not perfectly divisible:
    # append the leftover rows to the last partition.
    if len(df) % num_partitions != 0:
        partitions[-1] = pd.concat([partitions[-1], df.iloc[(num_partitions - 1) * partition_size + len(partitions[-1]):].copy()], ignore_index=True)
    return partitions
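
# --- Illustrative sketch: partition_data on an uneven split ---
# With 10 rows and 4 partitions, partition_size is 10 // 4 == 2, so the first three
# partitions get 2 rows each and the 2 leftover rows are appended to the last partition,
# which ends up with 4 rows.
#
#   parts = partition_data(pd.DataFrame({"value": range(10)}), 4)
#   [len(p) for p in parts]   # -> [2, 2, 2, 4]
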
# --- Main Execution Block (for local testing) ---
# To run on separate nodes, launch this script on each node with the appropriate
# command-line arguments ('coordinator' or 'worker <partition_id>') and ensure the
# data partition is available to that script instance.
if __name__ == "__main__":
    print("--- Distributed Pipeline Simulation (Local Test using Files) ---")
    print("To run on separate nodes, ensure the state and result files are accessible")
    print("by both nodes (e.g., on a shared network drive or cloud storage).")
    print("Run this script on one node with 'python combined_pipeline.py coordinator'")
    print("and on other nodes with 'python combined_pipeline.py worker <partition_id>'")
    print("(where <partition_id> is 2, 3, ..., NUM_WORKERS + 1).")
    print("-" * 60)

    # --- Prepare Dummy Data ---
    # Create a dummy dataset and split it into NUM_WORKERS + 1 partitions
    data = {'value': range(100)}  # Increased data size for more partitions
    df = pd.DataFrame(data)
    total_partitions = NUM_WORKERS + 1
    data_partitions = partition_data(df, total_partitions)
    print(f"Main: Data partitioned into {total_partitions} parts.")
    for i, part in enumerate(data_partitions):
        print(f"Main: Partition {i+1} size: {len(part)}")

    # Clean up previous worker state and result files before starting
    clean_up_worker_files(NUM_WORKERS)

    print("\n--- Launching Coordinator and Worker Processes (Local Simulation) ---")
    print("This simulates running them on different nodes by using multiprocessing.")
    print("In a real distributed setup, you would launch this script separately")
    print("on your different machines with the appropriate arguments.")

    # --- Command Line Arguments to Determine Role and Partition ---
    # This allows running the same script with different arguments
    # to act as the coordinator or a specific worker.
    # Usage:
    #   python combined_pipeline.py coordinator
    #   python combined_pipeline.py worker 2    # For worker with partition ID 2
    #   python combined_pipeline.py worker 3    # For worker with partition ID 3, etc.
    #   python combined_pipeline.py local_test  # For local simulation
    if len(sys.argv) > 1:
        role = sys.argv[1].lower()
        if role == 'coordinator':
            if len(data_partitions) < 1:
                print("Error: Not enough data partitions for coordinator.")
                sys.exit(1)
            coordinator_partition_data = data_partitions[0]  # Partition 1 data
            # Run the combined pipeline job function as the coordinator.
            # In a real multi-node setup, you would run this directly
            # without multiprocessing.Process if this script is the entry point.
            # For local simulation, we use Process to mimic separation.
            # run_partition_job(coordinator_partition_data, 1, simulate_failure=False)  # For direct execution
            coordinator_process = multiprocessing.Process(target=run_partition_job, args=(coordinator_partition_data, 1, False))
            coordinator_process.start()
            coordinator_process.join()
        elif role == 'worker':
            if len(sys.argv) > 2:
                try:
                    partition_id = int(sys.argv[2])
                    if partition_id < 2 or partition_id > NUM_WORKERS + 1:
                        print(f"Invalid worker partition ID: {partition_id}. Must be between 2 and {NUM_WORKERS + 1}.")
                        sys.exit(1)
                    if partition_id > len(data_partitions):
                        print(f"Error: Data partition {partition_id} does not exist.")
                        sys.exit(1)
                    worker_partition_data = data_partitions[partition_id - 1]  # Get the correct partition data (0-indexed)
                    # Set simulate_failure=True here for a specific worker ID to test failure handling.
                    # simulate_failure = (partition_id == 2)  # Example: Simulate failure for worker 2
                    simulate_failure = False  # Default: no failure simulation
                    # Run the combined pipeline job function as a worker.
                    # In a real multi-node setup, you would run this directly
                    # without multiprocessing.Process if this script is the entry point.
                    # For local simulation, we use Process to mimic separation.
                    # run_partition_job(worker_partition_data, partition_id, simulate_failure=simulate_failure)  # For direct execution
                    worker_process = multiprocessing.Process(target=run_partition_job, args=(worker_partition_data, partition_id, simulate_failure))
                    worker_process.start()
                    worker_process.join()
                except ValueError:
                    print("Invalid worker partition ID. Must be an integer.")
                    sys.exit(1)
            else:
                print("Please specify the worker partition ID (e.g., 'python combined_pipeline.py worker 2').")
                sys.exit(1)
        elif role == 'local_test':
            # This option runs the coordinator and all workers locally using multiprocessing
            print("\n--- Running Local Test Simulation ---")
            # Launch the coordinator process
            coordinator_process = multiprocessing.Process(target=run_partition_job, args=(data_partitions[0], 1, False))  # Coordinator is partition 1
            coordinator_process.start()
            # Launch worker processes
            worker_processes = []
            # Worker IDs are from 2 to NUM_WORKERS + 1
            for i in range(NUM_WORKERS):
                worker_id = i + 2  # Worker IDs start from 2
                worker_partition_data = data_partitions[i + 1]  # Workers get partitions from index 1 onwards
                # Set simulate_failure=True for a specific worker ID to test failure handling.
                # simulate_failure = (worker_id == 2)  # Example: Simulate failure for worker 2
                simulate_failure = False  # Default: no failure simulation
                worker_process = multiprocessing.Process(target=run_partition_job, args=(worker_partition_data, worker_id, simulate_failure))
                worker_processes.append(worker_process)
                worker_process.start()
            # Wait for all processes to finish
            coordinator_process.join()
            for p in worker_processes:
                p.join()
            print("\nLocal test simulation finished.")
        else:
            print("Invalid role specified. Use 'coordinator', 'worker <partition_id>', or 'local_test'.")
            sys.exit(1)
    else:
        print("Please specify the role to run: 'coordinator', 'worker <partition_id>', or 'local_test'.")
        print("Example: python combined_pipeline.py coordinator")
        print("Example: python combined_pipeline.py worker 2")
        print("Example: python combined_pipeline.py local_test")
        sys.exit(1)
    print("Main: Script finished.")