Skip to content

Instantly share code, notes, and snippets.

@stephenmcd
Last active June 4, 2024 12:02
Show Gist options
  • Save stephenmcd/39ded69946155930c347 to your computer and use it in GitHub Desktop.
Save stephenmcd/39ded69946155930c347 to your computer and use it in GitHub Desktop.
Parallel Merge Sort
import math
import multiprocessing
import random
import sys
import time
def merge(*args):
# Support explicit left/right args, as well as a two-item
# tuple which works more cleanly with multiprocessing.
left, right = args[0] if len(args) == 1 else args
left_length, right_length = len(left), len(right)
left_index, right_index = 0, 0
merged = []
while left_index < left_length and right_index < right_length:
if left[left_index] <= right[right_index]:
merged.append(left[left_index])
left_index += 1
else:
merged.append(right[right_index])
right_index += 1
if left_index == left_length:
merged.extend(right[right_index:])
else:
merged.extend(left[left_index:])
return merged
def merge_sort(data):
length = len(data)
if length <= 1:
return data
middle = length / 2
left = merge_sort(data[:middle])
right = merge_sort(data[middle:])
return merge(left, right)
def merge_sort_parallel(data):
# Creates a pool of worker processes, one per CPU core.
# We then split the initial data into partitions, sized
# equally per worker, and perform a regular merge sort
# across each partition.
processes = multiprocessing.cpu_count()
pool = multiprocessing.Pool(processes=processes)
size = int(math.ceil(float(len(data)) / processes))
data = [data[i * size:(i + 1) * size] for i in range(processes)]
data = pool.map(merge_sort, data)
# Each partition is now sorted - we now just merge pairs of these
# together using the worker pool, until the partitions are reduced
# down to a single sorted result.
while len(data) > 1:
# If the number of partitions remaining is odd, we pop off the
# last one and append it back after one iteration of this loop,
# since we're only interested in pairs of partitions to merge.
extra = data.pop() if len(data) % 2 == 1 else None
data = [(data[i], data[i + 1]) for i in range(0, len(data), 2)]
data = pool.map(merge, data) + ([extra] if extra else [])
return data[0]
if __name__ == "__main__":
size = int(sys.argv[-1]) if sys.argv[-1].isdigit() else 1000
data_unsorted = [random.randint(0, size) for _ in range(size)]
for sort in merge_sort, merge_sort_parallel:
start = time.time()
data_sorted = sort(data_unsorted)
end = time.time() - start
print sort.__name__, end, sorted(data_unsorted) == data_sorted
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment