Last active
June 4, 2024 12:02
-
-
Save stephenmcd/39ded69946155930c347 to your computer and use it in GitHub Desktop.
Parallel Merge Sort
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import multiprocessing | |
import random | |
import sys | |
import time | |
def merge(*args): | |
# Support explicit left/right args, as well as a two-item | |
# tuple which works more cleanly with multiprocessing. | |
left, right = args[0] if len(args) == 1 else args | |
left_length, right_length = len(left), len(right) | |
left_index, right_index = 0, 0 | |
merged = [] | |
while left_index < left_length and right_index < right_length: | |
if left[left_index] <= right[right_index]: | |
merged.append(left[left_index]) | |
left_index += 1 | |
else: | |
merged.append(right[right_index]) | |
right_index += 1 | |
if left_index == left_length: | |
merged.extend(right[right_index:]) | |
else: | |
merged.extend(left[left_index:]) | |
return merged | |
def merge_sort(data): | |
length = len(data) | |
if length <= 1: | |
return data | |
middle = length / 2 | |
left = merge_sort(data[:middle]) | |
right = merge_sort(data[middle:]) | |
return merge(left, right) | |
def merge_sort_parallel(data): | |
# Creates a pool of worker processes, one per CPU core. | |
# We then split the initial data into partitions, sized | |
# equally per worker, and perform a regular merge sort | |
# across each partition. | |
processes = multiprocessing.cpu_count() | |
pool = multiprocessing.Pool(processes=processes) | |
size = int(math.ceil(float(len(data)) / processes)) | |
data = [data[i * size:(i + 1) * size] for i in range(processes)] | |
data = pool.map(merge_sort, data) | |
# Each partition is now sorted - we now just merge pairs of these | |
# together using the worker pool, until the partitions are reduced | |
# down to a single sorted result. | |
while len(data) > 1: | |
# If the number of partitions remaining is odd, we pop off the | |
# last one and append it back after one iteration of this loop, | |
# since we're only interested in pairs of partitions to merge. | |
extra = data.pop() if len(data) % 2 == 1 else None | |
data = [(data[i], data[i + 1]) for i in range(0, len(data), 2)] | |
data = pool.map(merge, data) + ([extra] if extra else []) | |
return data[0] | |
if __name__ == "__main__": | |
size = int(sys.argv[-1]) if sys.argv[-1].isdigit() else 1000 | |
data_unsorted = [random.randint(0, size) for _ in range(size)] | |
for sort in merge_sort, merge_sort_parallel: | |
start = time.time() | |
data_sorted = sort(data_unsorted) | |
end = time.time() - start | |
print sort.__name__, end, sorted(data_unsorted) == data_sorted |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment