Benchmark percentile calculation in RLlib's Stats objects
import time
import random
import heapq
import numpy as np
from typing import List

from ray.rllib.utils.metrics.stats import Stats


def generate_random_values(n: int) -> List[float]:
    """Generate n random float values for testing."""
    return [random.uniform(0, 1000) for _ in range(n)]
def benchmark_sorting_during_peek(n_values: int, num_runs: int = 5) -> dict:
    """
    Benchmark sorting performance during peek/reduce operations.

    This tests the sorting in the _reduced_values() method.
    """
    print(f"\n📊 Benchmarking sorting during peek() with {n_values:,} values...")
    results = {
        'peek_compile_false': [],
        'peek_compile_true': [],
        'reduce_compile_false': [],
        'reduce_compile_true': [],
    }
    for run in range(num_runs):
        print(f"  Run {run + 1}/{num_runs}")

        # Create a Stats object with percentiles enabled. Use a window size
        # equal to n_values so that it holds all pushed values.
        stats = Stats(reduce=None, percentiles=True, window=n_values)

        # Generate and push random values.
        values = generate_random_values(n_values)
        for value in values:
            stats.push(value)
        values_copy = values.copy()

        # Benchmark peek(compile=False) - returns the sorted list.
        start = time.perf_counter()
        sorted_values = stats.peek(compile=False)
        duration = time.perf_counter() - start
        results['peek_compile_false'].append(duration)

        # Verify it's actually sorted.
        assert sorted_values == sorted(values), "Values should be sorted"

        # Benchmark peek(compile=True) - computes percentiles.
        stats.values = values_copy.copy()
        start = time.perf_counter()
        percentiles = stats.peek(compile=True)
        duration = time.perf_counter() - start
        results['peek_compile_true'].append(duration)

        # Verify the percentiles structure.
        assert isinstance(percentiles, dict), "Should return percentiles dict"
        expected_keys = [0, 50, 75, 90, 95, 99, 100]
        assert set(percentiles.keys()) == set(expected_keys), "Should have default percentiles"
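        # Illustration (not timed): once the window is sorted, a single
        # percentile lookup is cheap. Assuming nearest-rank semantics
        # (RLlib may interpolate differently), the p-th percentile of a
        # sorted list `vs` of n values is roughly:
        #   vs[min(n - 1, ceil(p / 100 * n) - 1)]
        # i.e. O(1) per lookup, so the time measured here should be
        # dominated by the O(n log n) sort itself.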
        # Benchmark reduce(compile=False) - returns a Stats object holding
        # the sorted list.
        stats.values = values_copy.copy()
        start = time.perf_counter()
        reduced_stats = stats.reduce(compile=False)
        duration = time.perf_counter() - start
        results['reduce_compile_false'].append(duration)

        # Benchmark reduce(compile=True) - computes percentiles.
        stats.values = values_copy.copy()
        start = time.perf_counter()
        reduced_percentiles = stats.reduce(compile=True)
        duration = time.perf_counter() - start
        results['reduce_compile_true'].append(duration)

    # Calculate averages.
    avg_results = {}
    for key, times in results.items():
        avg_results[key] = {
            'avg_time': np.mean(times),
            'std_time': np.std(times),
            'min_time': np.min(times),
            'max_time': np.max(times),
        }
    return avg_results
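# Expectation: both peek() and reduce() are dominated by an O(n log n) sort of
# the window, so their times should roughly track the pure list.sort() and
# sorted() baselines measured further below.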
def benchmark_merging_performance(n_values_per_stats: int, num_stats: int = 10, num_runs: int = 3) -> dict:
    """
    Benchmark merging performance when combining multiple Stats objects.

    This tests the heapq.merge() performance in the merge_in_parallel() method.
    """
| print(f"\n🔀 Benchmarking merging {num_stats} Stats objects with {n_values_per_stats:,} values each...") | |
| results = { | |
| 'merge_time': [], | |
| 'total_values': n_values_per_stats * num_stats | |
| } | |
| for run in range(num_runs): | |
| print(f" Run {run + 1}/{num_runs}") | |
| # Create multiple Stats objects with random values | |
| stats_objects = [] | |
| for i in range(num_stats): | |
| stats = Stats(reduce=None, percentiles=True, window=n_values_per_stats * num_stats) | |
| values = generate_random_values(n_values_per_stats) | |
| # Add values in sorted order to simulate real percentile use case | |
| # where individual Stats maintain sorted order | |
| values.sort() | |
| for value in values: | |
| stats.push(value) | |
| stats_objects.append(stats) | |
| # Create a base Stats object | |
| base_stats = Stats(reduce=None, percentiles=True, window=n_values_per_stats * num_stats) | |
| base_values = generate_random_values(n_values_per_stats) | |
| base_values.sort() | |
| for value in base_values: | |
| base_stats.push(value) | |
| # Benchmark the merging operation | |
| start = time.perf_counter() | |
| base_stats.merge_in_parallel(*stats_objects) | |
| duration = time.perf_counter() - start | |
| results['merge_time'].append(duration) | |
| # Verify the result is still sorted | |
| merged_values = base_stats.peek(compile=False) | |
| assert merged_values == sorted(merged_values), "Merged values should be sorted" | |
| assert len(merged_values) == results['total_values'], f"Should have {results['total_values']} values" | |
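        # (Note: base_stats itself also held n_values_per_stats values before
        # the merge; the window of n_values_per_stats * num_stats is presumably
        # what caps the merged list at exactly total_values here - otherwise
        # this assertion would have to account for the extra base values.)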
    # Calculate averages.
    avg_results = {
        'avg_time': np.mean(results['merge_time']),
        'std_time': np.std(results['merge_time']),
        'min_time': np.min(results['merge_time']),
        'max_time': np.max(results['merge_time']),
        'total_values': results['total_values'],
    }
    return avg_results
def benchmark_pure_sorting_comparison(n_values: int, num_runs: int = 5) -> dict:
    """Benchmark pure Python sorting for comparison with Stats sorting."""
    print(f"\n🔧 Benchmarking pure Python sorting with {n_values:,} values for comparison...")
    results = {
        'list_sort': [],
        'sorted_builtin': [],
        'numpy_sort': [],
    }
    for run in range(num_runs):
        values = generate_random_values(n_values)

        # Test list.sort() (in-place, so sort a copy).
        values_copy = values.copy()
        start = time.perf_counter()
        values_copy.sort()
        duration = time.perf_counter() - start
        results['list_sort'].append(duration)

        # Test the sorted() builtin.
        start = time.perf_counter()
        sorted_values = sorted(values)
        duration = time.perf_counter() - start
        results['sorted_builtin'].append(duration)

        # Test numpy sort.
        np_values = np.array(values)
        start = time.perf_counter()
        np_sorted = np.sort(np_values)
        duration = time.perf_counter() - start
        results['numpy_sort'].append(duration)
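        # (Note: the list -> ndarray conversion above sits outside the timed
        # region, so only np.sort() itself is measured, not the conversion
        # cost.)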
    # Calculate averages.
    avg_results = {}
    for key, times in results.items():
        avg_results[key] = {
            'avg_time': np.mean(times),
            'std_time': np.std(times),
            'min_time': np.min(times),
            'max_time': np.max(times),
        }
    return avg_results
def benchmark_heapq_merge_comparison(n_values_per_list: int, num_lists: int = 10, num_runs: int = 3) -> dict:
    """Benchmark pure heapq.merge() for comparison with Stats merging."""
    print(f"\n🔧 Benchmarking pure heapq.merge with {num_lists} lists of {n_values_per_list:,} values each...")
    results = {'heapq_merge': []}
    for run in range(num_runs):
        # Create sorted lists.
        sorted_lists = []
        for _ in range(num_lists):
            values = generate_random_values(n_values_per_list)
            values.sort()
            sorted_lists.append(values)

        # Benchmark heapq.merge (wrapping in list() drains the lazy iterator,
        # so the full merge cost is measured).
        start = time.perf_counter()
        merged = list(heapq.merge(*sorted_lists))
        duration = time.perf_counter() - start
        results['heapq_merge'].append(duration)

    avg_results = {
        'avg_time': np.mean(results['heapq_merge']),
        'std_time': np.std(results['heapq_merge']),
        'min_time': np.min(results['heapq_merge']),
        'max_time': np.max(results['heapq_merge']),
        'total_values': n_values_per_list * num_lists,
    }
    return avg_results
def format_time(seconds: float) -> str:
    """Format a duration in a human-readable way."""
    if seconds < 0.001:
        return f"{seconds * 1000000:.1f} μs"
    elif seconds < 1.0:
        return f"{seconds * 1000:.1f} ms"
    else:
        return f"{seconds:.3f} s"
def print_results(test_name: str, results: dict, n_values: int):
    """Print benchmark results in a formatted way."""
    print(f"\n{'='*50}")
    print(f"📊 {test_name} Results ({n_values:,} values)")
    print(f"{'='*50}")
    if 'avg_time' in results:
        # Single-test results.
        print(f"Average time: {format_time(results['avg_time'])}")
        print(f"Std deviation: {format_time(results['std_time'])}")
        print(f"Min time: {format_time(results['min_time'])}")
        print(f"Max time: {format_time(results['max_time'])}")
        if 'total_values' in results:
            print(f"Total values: {results['total_values']:,}")
    else:
        # Multiple-test results.
        for test_type, test_results in results.items():
            print(f"\n{test_type}:")
            print(f"  Average: {format_time(test_results['avg_time'])}")
            print(f"  Std dev: {format_time(test_results['std_time'])}")
            print(f"  Range: {format_time(test_results['min_time'])} - {format_time(test_results['max_time'])}")
| print("RLlib Stats Sorting Performance Benchmark") | |
| print("=" * 60) | |
| # Test configurations | |
| test_sizes = [1000, 10000, 1000000] # 1K, 10K, 1M values | |
| print(f"Testing with {len(test_sizes)} different data sizes:") | |
| for size in test_sizes: | |
| print(f" - {size:,} values") | |
| print(f"\nRandom seed: {random.seed(42)}") # For reproducible results | |
| random.seed(42) | |
| np.random.seed(42) | |
all_results = {}

for n_values in test_sizes:
    print(f"\n{'🎯 TESTING WITH ' + f'{n_values:,} VALUES':=^80}")

    # 1. Benchmark sorting during peek/reduce operations.
    sorting_results = benchmark_sorting_during_peek(n_values, num_runs=3 if n_values >= 100000 else 5)
    all_results[f'sorting_{n_values}'] = sorting_results
    print_results("Stats Sorting Operations", sorting_results, n_values)

    # 2. Benchmark merging performance. For 1M values, use fewer Stats
    # objects to keep the test reasonable.
    num_stats_objects = 5 if n_values >= 100000 else 10
    values_per_stats = max(100, n_values // num_stats_objects)
    merging_results = benchmark_merging_performance(
        values_per_stats,
        num_stats_objects,
        num_runs=2 if n_values >= 100000 else 3,
    )
    all_results[f'merging_{n_values}'] = merging_results
    print_results("Stats Merging Operations", merging_results, merging_results['total_values'])

    # 3. Benchmark pure sorting for comparison.
    pure_sorting_results = benchmark_pure_sorting_comparison(n_values, num_runs=3 if n_values >= 100000 else 5)
    all_results[f'pure_sorting_{n_values}'] = pure_sorting_results
    print_results("Pure Python Sorting Comparison", pure_sorting_results, n_values)

    # 4. Benchmark pure heapq.merge for comparison.
    heapq_results = benchmark_heapq_merge_comparison(
        values_per_stats,
        num_stats_objects,
        num_runs=2 if n_values >= 100000 else 3,
    )
    all_results[f'heapq_{n_values}'] = heapq_results
    print_results("Pure heapq.merge Comparison", heapq_results, heapq_results['total_values'])
# Summary.
print(f"\n{'📊 PERFORMANCE SUMMARY':=^80}")

print("\nSorting Performance (average times):")
print("-" * 40)
for n_values in test_sizes:
    sorting_res = all_results[f'sorting_{n_values}']
    pure_res = all_results[f'pure_sorting_{n_values}']
    print(f"\n{n_values:,} values:")
    print(f"  Stats peek(compile=False): {format_time(sorting_res['peek_compile_false']['avg_time'])}")
    print(f"  Stats peek(compile=True): {format_time(sorting_res['peek_compile_true']['avg_time'])}")
    print(f"  Pure list.sort(): {format_time(pure_res['list_sort']['avg_time'])}")
    print(f"  Pure sorted(): {format_time(pure_res['sorted_builtin']['avg_time'])}")
    print(f"  NumPy sort: {format_time(pure_res['numpy_sort']['avg_time'])}")

print("\nMerging Performance (average times):")
print("-" * 40)
for n_values in test_sizes:
    merging_res = all_results[f'merging_{n_values}']
    heapq_res = all_results[f'heapq_{n_values}']
    print(f"\n~{n_values:,} total values:")
    print(f"  Stats merge_in_parallel(): {format_time(merging_res['avg_time'])}")
    print(f"  Pure heapq.merge(): {format_time(heapq_res['avg_time'])}")
    print(f"  Total values merged: {merging_res['total_values']:,}")

print(f"\n{'BENCHMARK COMPLETE':=^80}")
| # Results from running this locally on my M1 MacBook Pro | |
| # RLlib Stats Sorting Performance Benchmark | |
| # ============================================================ | |
| # Testing with 3 different data sizes: | |
| # - 1,000 values | |
| # - 10,000 values | |
| # - 1,000,000 values | |
| # Random seed: None | |
| # ==========================🎯 TESTING WITH 1,000 VALUES=========================== | |
| # 📊 Benchmarking sorting during peek() with 1,000 values... | |
| # Run 1/5 | |
| # Run 2/5 | |
| # Run 3/5 | |
| # Run 4/5 | |
| # Run 5/5 | |
| # ================================================== | |
| # 📊 Stats Sorting Operations Results (1,000 values) | |
| # ================================================== | |
| # peek_compile_false: | |
| # Average: 77.8 μs | |
| # Std dev: 1.6 μs | |
| # Range: 76.1 μs - 80.4 μs | |
| # peek_compile_true: | |
| # Average: 67.2 μs | |
| # Std dev: 2.2 μs | |
| # Range: 64.1 μs - 70.8 μs | |
| # reduce_compile_false: | |
| # Average: 84.8 μs | |
| # Std dev: 5.1 μs | |
| # Range: 79.2 μs - 93.8 μs | |
| # reduce_compile_true: | |
| # Average: 66.1 μs | |
| # Std dev: 4.5 μs | |
| # Range: 59.0 μs - 71.8 μs | |
| # 🔀 Benchmarking merging 10 Stats objects with 100 values each... | |
| # Run 1/3 | |
| # Run 2/3 | |
| # Run 3/3 | |
| # ================================================== | |
| # 📊 Stats Merging Operations Results (1,000 values) | |
| # ================================================== | |
| # Average time: 229.5 μs | |
| # Std deviation: 19.5 μs | |
| # Min time: 209.5 μs | |
| # Max time: 256.0 μs | |
| # Total values: 1,000 | |
| # 🔧 Benchmarking pure Python sorting with 1,000 values for comparison... | |
| # ================================================== | |
| # 📊 Pure Python Sorting Comparison Results (1,000 values) | |
| # ================================================== | |
| # list_sort: | |
| # Average: 72.8 μs | |
| # Std dev: 2.1 μs | |
| # Range: 69.3 μs - 75.0 μs | |
| # sorted_builtin: | |
| # Average: 69.4 μs | |
| # Std dev: 2.1 μs | |
| # Range: 66.2 μs - 71.8 μs | |
| # numpy_sort: | |
| # Average: 43.3 μs | |
| # Std dev: 5.5 μs | |
| # Range: 39.8 μs - 54.2 μs | |
| # 🔧 Benchmarking pure heapq.merge with 10 lists of 100 values each... | |
| # ================================================== | |
| # 📊 Pure heapq.merge Comparison Results (1,000 values) | |
| # ================================================== | |
| # Average time: 177.5 μs | |
| # Std deviation: 2.1 μs | |
| # Min time: 175.7 μs | |
| # Max time: 180.5 μs | |
| # Total values: 1,000 | |
| # ==========================🎯 TESTING WITH 10,000 VALUES========================== | |
| # 📊 Benchmarking sorting during peek() with 10,000 values... | |
| # Run 1/5 | |
| # Run 2/5 | |
| # Run 3/5 | |
| # Run 4/5 | |
| # Run 5/5 | |
| # ================================================== | |
| # 📊 Stats Sorting Operations Results (10,000 values) | |
| # ================================================== | |
| # peek_compile_false: | |
| # Average: 951.1 μs | |
| # Std dev: 13.3 μs | |
| # Range: 941.1 μs - 977.3 μs | |
| # peek_compile_true: | |
| # Average: 916.2 μs | |
| # Std dev: 8.6 μs | |
| # Range: 907.8 μs - 928.1 μs | |
| # reduce_compile_false: | |
| # Average: 1.1 ms | |
| # Std dev: 25.1 μs | |
| # Range: 1.1 ms - 1.1 ms | |
| # reduce_compile_true: | |
| # Average: 957.7 μs | |
| # Std dev: 19.9 μs | |
| # Range: 938.5 μs - 986.0 μs | |
| # 🔀 Benchmarking merging 10 Stats objects with 1,000 values each... | |
| # Run 1/3 | |
| # Run 2/3 | |
| # Run 3/3 | |
| # ================================================== | |
| # 📊 Stats Merging Operations Results (10,000 values) | |
| # ================================================== | |
| # Average time: 2.1 ms | |
| # Std deviation: 29.5 μs | |
| # Min time: 2.0 ms | |
| # Max time: 2.1 ms | |
| # Total values: 10,000 | |
| # 🔧 Benchmarking pure Python sorting with 10,000 values for comparison... | |
| # ================================================== | |
| # 📊 Pure Python Sorting Comparison Results (10,000 values) | |
| # ================================================== | |
| # list_sort: | |
| # Average: 929.2 μs | |
| # Std dev: 4.5 μs | |
| # Range: 922.8 μs - 935.5 μs | |
| # sorted_builtin: | |
| # Average: 997.1 μs | |
| # Std dev: 38.8 μs | |
| # Range: 921.9 μs - 1.0 ms | |
| # numpy_sort: | |
| # Average: 511.7 μs | |
| # Std dev: 10.8 μs | |
| # Range: 497.0 μs - 526.6 μs | |
| # 🔧 Benchmarking pure heapq.merge with 10 lists of 1,000 values each... | |
| # ================================================== | |
| # 📊 Pure heapq.merge Comparison Results (10,000 values) | |
| # ================================================== | |
| # Average time: 1.8 ms | |
| # Std deviation: 63.7 μs | |
| # Min time: 1.7 ms | |
| # Max time: 1.9 ms | |
| # Total values: 10,000 | |
| # ========================🎯 TESTING WITH 1,000,000 VALUES========================= | |
| # 📊 Benchmarking sorting during peek() with 1,000,000 values... | |
| # Run 1/3 | |
| # Run 2/3 | |
| # Run 3/3 | |
| # ================================================== | |
| # 📊 Stats Sorting Operations Results (1,000,000 values) | |
| # ================================================== | |
| # peek_compile_false: | |
| # Average: 174.7 ms | |
| # Std dev: 5.0 ms | |
| # Range: 168.1 ms - 180.3 ms | |
| # peek_compile_true: | |
| # Average: 157.5 ms | |
| # Std dev: 5.2 ms | |
| # Range: 152.5 ms - 164.6 ms | |
| # reduce_compile_false: | |
| # Average: 213.1 ms | |
| # Std dev: 13.5 ms | |
| # Range: 194.2 ms - 224.8 ms | |
| # reduce_compile_true: | |
| # Average: 176.7 ms | |
| # Std dev: 9.7 ms | |
| # Range: 165.5 ms - 189.2 ms | |
| # 🔀 Benchmarking merging 5 Stats objects with 200,000 values each... | |
| # Run 1/2 | |
| # Run 2/2 | |
| # ================================================== | |
| # 📊 Stats Merging Operations Results (1,000,000 values) | |
| # ================================================== | |
| # Average time: 302.9 ms | |
| # Std deviation: 3.9 ms | |
| # Min time: 299.0 ms | |
| # Max time: 306.8 ms | |
| # Total values: 1,000,000 | |
| # 🔧 Benchmarking pure Python sorting with 1,000,000 values for comparison... | |
| # ================================================== | |
| # 📊 Pure Python Sorting Comparison Results (1,000,000 values) | |
| # ================================================== | |
| # list_sort: | |
| # Average: 158.2 ms | |
| # Std dev: 4.1 ms | |
| # Range: 153.6 ms - 163.7 ms | |
| # sorted_builtin: | |
| # Average: 170.4 ms | |
| # Std dev: 6.7 ms | |
| # Range: 164.2 ms - 179.8 ms | |
| # numpy_sort: | |
| # Average: 77.2 ms | |
| # Std dev: 1.9 ms | |
| # Range: 74.5 ms - 78.7 ms | |
| # 🔧 Benchmarking pure heapq.merge with 5 lists of 200,000 values each... | |
| # ================================================== | |
| # 📊 Pure heapq.merge Comparison Results (1,000,000 values) | |
| # ================================================== | |
| # Average time: 192.1 ms | |
| # Std deviation: 8.6 ms | |
| # Min time: 183.5 ms | |
| # Max time: 200.7 ms | |
| # Total values: 1,000,000 | |
| # =============================📊 PERFORMANCE SUMMARY============================== | |
| # Sorting Performance (average times): | |
| # ---------------------------------------- | |
| # 1,000 values: | |
| # Stats peek(compile=False): 77.8 μs | |
| # Stats peek(compile=True): 67.2 μs | |
| # Pure list.sort(): 72.8 μs | |
| # Pure sorted(): 69.4 μs | |
| # NumPy sort: 43.3 μs | |
| # 10,000 values: | |
| # Stats peek(compile=False): 951.1 μs | |
| # Stats peek(compile=True): 916.2 μs | |
| # Pure list.sort(): 929.2 μs | |
| # Pure sorted(): 997.1 μs | |
| # NumPy sort: 511.7 μs | |
| # 1,000,000 values: | |
| # Stats peek(compile=False): 174.7 ms | |
| # Stats peek(compile=True): 157.5 ms | |
| # Pure list.sort(): 158.2 ms | |
| # Pure sorted(): 170.4 ms | |
| # NumPy sort: 77.2 ms | |
| # Merging Performance (average times): | |
| # ---------------------------------------- | |
| # ~1,000 total values: | |
| # Stats merge_in_parallel(): 229.5 μs | |
| # Pure heapq.merge(): 177.5 μs | |
| # Total values merged: 1,000 | |
| # ~10,000 total values: | |
| # Stats merge_in_parallel(): 2.1 ms | |
| # Pure heapq.merge(): 1.8 ms | |
| # Total values merged: 10,000 | |
| # ~1,000,000 total values: | |
| # Stats merge_in_parallel(): 302.9 ms | |
| # Pure heapq.merge(): 192.1 ms | |
| # Total values merged: 1,000,000 | |
| # ===============================BENCHMARK COMPLETE=============================== |