Speculative execution on both branches of a conditional
""" | |
Provides speculative_if(cond, branch1, branch2), which runs branch1 followed by | |
branch2 while cond is running. | |
""" | |
from dataclasses import dataclass | |
import itertools | |
import multiprocessing as mp | |
import time | |
import matplotlib.pyplot as plt | |
def enqueue_ret(key, f, q): | |
""" | |
Wraps the function f to return its result through q, | |
keyed by key. | |
""" | |
q.put((key, f())) | |


def get_result(q: mp.Queue, key):
    """
    Blocks until a result tuple whose first element equals key is found in the queue.
    Elements encountered before the result are put back into the queue at the end,
    and will subsequently be encountered in reverse order; this is a side effect of
    avoiding any unnecessary copying and scanning. A deque could match the asymptotic
    complexity while only rotating the elements, at the cost of some overhead.
    """
    backlog = []
    while True:
        key_, val = q.get()
        if key_ == key:
            while backlog:
                q.put(backlog.pop())
            return val
        backlog.append((key_, val))
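

# Illustrative sketch (not part of the original gist): the docstring above mentions
# that a deque could avoid the reversal by merely rotating past non-matching entries.
# With a plain collections.deque standing in for the queue (single-process setting
# only; this is NOT a drop-in replacement for mp.Queue), that idea might look like:
def get_result_from_deque(dq, key):
    """Return the value stored under key in the collections.deque dq,
    rotating past non-matching entries so their order is rotated, not reversed."""
    for _ in range(len(dq)):
        key_, val = dq[0]
        if key_ == key:
            dq.popleft()
            return val
        dq.rotate(-1)  # move the non-matching head to the tail
    raise KeyError(key)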


def speculative_if(cond, branch1, branch2):
    """
    Speculatively runs branch1, then branch2, while cond is still being evaluated,
    and returns the selected branch's result as soon as both it and cond are ready.
    """
    idcond, id1, id2 = range(3)
    # all results will be put into this queue
    q = mp.Queue()
    # start cond_p
    cond_p = mp.Process(target=enqueue_ret, args=(idcond, cond, q))
    cond_p.start()
    # start branch1_p
    branch1_p = mp.Process(target=enqueue_ret, args=(id1, branch1, q))
    branch1_p.start()
    # wait for either cond_p or branch1_p to finish
    funcname, ret = q.get()
    if funcname == idcond:
        assert isinstance(ret, bool)
        if ret:
            return get_result(q, id1)
        else:
            # An alternative implementation would send the termination signal
            # to branch1, start branch2 asynchronously, and then join on both.
            # That would be better if branch1 is slow to terminate.
            # forcibly stop branch1_p
            branch1_p.terminate()
            branch1_p.join()
            # run branch2 in the main process
            return branch2()
    else:
        assert funcname == id1
        # start branch2
        branch2_p = mp.Process(target=enqueue_ret, args=(id2, branch2, q))
        branch2_p.start()
        # wait for cond_p to finish
        cond_ret = get_result(q, idcond)
        assert isinstance(cond_ret, bool)
        if cond_ret:
            # forcibly stop branch2_p
            branch2_p.terminate()
            branch2_p.join()
            return ret
        else:
            return get_result(q, id2)
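

# Illustrative sketch (not part of the original gist): the comment in the cond-False,
# branch1-unfinished case above describes an alternative that overlaps branch2 with
# branch1's shutdown instead of joining branch1 first. Reusing enqueue_ret and
# get_result, a minimal version of that idea might look like:
def _terminate_branch1_and_run_branch2(q, branch1_p, branch2, id2):
    branch1_p.terminate()  # ask branch1 to stop, but don't wait for it yet
    branch2_p = mp.Process(target=enqueue_ret, args=(id2, branch2, q))
    branch2_p.start()      # branch2 overlaps with branch1's shutdown
    ret = get_result(q, id2)
    branch1_p.join()       # reap both children before returning
    branch2_p.join()
    return ret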


class SlowF:
    """
    A function object that sleeps for sleep_time seconds and then returns ret.
    A class is needed because closures are not pickleable.
    """
    def __init__(self, sleep_time, ret):
        self.sleep_time = sleep_time
        self.ret = ret

    def __call__(self):
        time.sleep(self.sleep_time)
        return self.ret
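

# Usage sketch (not part of the original gist): callables handed to speculative_if
# must be picklable, so they are written as top-level callables like SlowF instances
# (a closure or lambda would fail to pickle under multiprocessing). Under the timings
# assumed here,
#
#     speculative_if(SlowF(2.0, True), SlowF(1.0, "fast"), SlowF(3.0, "slow"))
#
# should return "fast" after roughly max(2, 1) = 2 seconds, rather than the
# 2 + 1 = 3 seconds that evaluating cond and then the chosen branch sequentially
# would take.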


def get_expected_time(cond_t, t1, t2, which):
    """
    Returns the ideal expected time for the speculative_if function,
    assuming zero overhead from spawning, switching, and signalling
    processes.
    """
    if which:
        return max(cond_t, t1)  # cond and 1 run in parallel
    else:
        if cond_t <= t1:
            return cond_t + t2  # 2 starts right after cond
        else:
            t2_end = t1 + t2  # 2 starts right after 1
            return max(cond_t, t2_end)  # but might have to wait for cond
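

# Worked example (not part of the original gist), checking get_expected_time by hand
# for cond_t=2, t1=1, t2=3:
#   which=True:  cond and branch1 overlap, so the answer is max(2, 1) == 2.
#   which=False: cond (2s) outlasts branch1 (1s), so branch2 starts at t=1 and
#                finishes at t=4; cond is long done by then, so max(2, 4) == 4.
def _check_expected_time_example():
    assert get_expected_time(2, 1, 3, True) == 2
    assert get_expected_time(2, 1, 3, False) == 4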


@dataclass
class TestResult:
    cond_t: float
    t1: float
    t2: float
    which: bool
    actual_time: float

    @property
    def expected_time(self):
        return get_expected_time(self.cond_t, self.t1, self.t2, self.which)

    @property
    def diff(self):
        return self.actual_time - self.expected_time

    def __str__(self):
        return (f"cond_t={self.cond_t}, t1={self.t1}, t2={self.t2}, "
                f"which={'t1' if self.which else 't2'}, "
                f"expected_time={self.expected_time:1.2f}, "
                f"actual_time={self.actual_time:1.2f}, diff={self.diff:1.2f}")


def get_test_results():
    method_start_time = time.time()
    rets = []
    for cond_t, t1, t2 in itertools.product(range(4), repeat=3):
        # cond_t, t1, t2 = 3 * cond_t, 3 * t1, 3 * t2
        for which in [True, False]:
            cond = SlowF(cond_t, which)
            branch1 = SlowF(t1, 1)
            branch2 = SlowF(t2, 2)
            start_time = time.time()
            ret = speculative_if(cond, branch1, branch2)
            actual_time = time.time() - start_time
            assert ret == (1 if which else 2)
            rets.append(TestResult(cond_t, t1, t2, which, actual_time))
            print(rets[-1])
    print(f"get_test_results time: {time.time() - method_start_time}")
    return rets


def main():
    rets = get_test_results()
    print("Sorted by decreasing diff:")
    decreasing_diffs = sorted(rets, key=lambda r: r.diff, reverse=True)
    for r in decreasing_diffs:
        print(r)
    expected_times = sorted(set(r.expected_time for r in rets))
    plt.plot(expected_times, expected_times, color="red", label="theory")
    plt.scatter([r.expected_time for r in rets], [r.actual_time for r in rets],
                marker="x", label="practice")
    # label the worst offender for each expected time
    outliers = []
    for expected_time in expected_times:
        same_expected_time = [r for r in rets if r.expected_time == expected_time]
        outlier = max(same_expected_time, key=lambda r: r.diff)
        outliers.append(outlier)
    for r in outliers:
        plt.text(r.expected_time, r.actual_time,
                 f"cond_t={r.cond_t} t1={r.t1} t2={r.t2} "
                 f"which={'t1' if r.which else 't2'} diff={r.diff:1.2f}")
    plt.xlabel("Expected time")
    plt.ylabel("Actual time")
    plt.legend()
    plt.show()


if __name__ == '__main__':
    main()