Last active: May 25, 2019 08:36
-
-
Save graipher/bc94156fe6e740a55b49dcc1e631c027 to your computer and use it in GitHub Desktop.
Timing measurements for https://stackoverflow.com/q/56288015/4042267
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from functools import partial | |
import timeit | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
from itertools import starmap, count | |
from string import digits | |
from random import choices | |
def get_time(func, *args):
    """Time a single call ``func(*args)`` repeated 5 times.

    Each repeat executes the call exactly once (``number=1``).

    Returns:
        A ``(best, err)`` pair: the minimum of the 5 timings (seconds, as
        reported by ``timeit``) and their standard error,
        ``std / sqrt(5)``.
    """
    # Accepting *args lets get_times(..., star=True) forward an unpacked
    # argument tuple through starmap; the one-argument form get_time(func, x)
    # behaves exactly as before.
    timer = timeit.Timer(partial(func, *args))
    timings = timer.repeat(repeat=5, number=1)
    return np.min(timings), np.std(timings) / np.sqrt(len(timings))
def get_times(func, inputs, star=False):
    """Time ``func`` once per element of ``inputs`` via ``get_time``.

    With ``star=True`` every element is unpacked into positional arguments
    (``itertools.starmap``); otherwise it is passed as the single argument.
    Returns a NumPy array with one row per input holding whatever pair
    ``get_time`` produced for it.
    """
    timed_call = partial(get_time, func)
    apply = starmap if star else map
    results = list(apply(timed_call, inputs))
    return np.array(results)
def get_df(funcs, inputs, key, star=False):
    """Benchmark each callable in ``funcs`` on ``inputs`` and tabulate results.

    The returned DataFrame starts with an ``"x"`` column holding
    ``key(input)`` for every input, followed by a ``<label>`` and
    ``<label>_err`` column pair per function, filled from the two columns of
    ``get_times``. ``<label>`` is the function's ``__name__``, or its index in
    ``funcs`` for lambdas and for callables that have no ``__name__``
    (e.g. ``functools.partial`` objects).
    """
    df = pd.DataFrame(list(map(key, inputs)), columns=["x"])
    for i, func in enumerate(funcs):
        # partial objects and other callables may lack __name__; treat them
        # like lambdas instead of raising AttributeError.
        name = getattr(func, "__name__", "<lambda>")
        label = str(i) if name == "<lambda>" else name
        df[label], df[label + "_err"] = get_times(func, inputs, star=star).T
    return df
def counter():
    """Build a callable that ignores its positional arguments and returns
    successive integers 0, 1, 2, ... on each call."""
    ticks = count()

    def wrapper(*_ignored):
        return next(ticks)

    return wrapper
def identity(x):
    """Return the argument itself; used as the default ``key`` of plot_times."""
    result = x
    return result
def plot_times(funcs, inputs, key=identity, xlabel="x", ylabel="Time [s]", logx=False, logy=False, ratio=False, star=False):
    """Benchmark ``funcs`` on ``inputs`` and plot one error-bar curve each.

    Args:
        funcs: callables to time (see get_df for how curves are labelled).
        inputs: inputs passed to every function (unpacked if ``star``).
        key: maps each input to its x-axis value.
        xlabel, ylabel: axis labels.
        logx, logy: use a logarithmic x / y axis.
        ratio: divide every curve (and its error) by the first function's
            timings, so the first curve becomes 1 everywhere.
        star: forwarded to get_df; unpack each input as positional args.
    """
    df = get_df(funcs, inputs, key, star)
    # get_df's layout is "x", then (label, label_err) per function, so
    # column 1 is the first function's timing column.
    reference = df.columns[1]
    for label in df.columns[1::2]:
        x, y, yerr = df["x"], df[label], df[label + "_err"]
        if ratio:
            # The reference's own uncertainty is not propagated here.
            y, yerr = y / df[reference], yerr / df[reference]
        plt.errorbar(x, y, yerr, fmt='o-', label=label)
    plt.xlabel(xlabel)
    if ratio:
        # Name the function being normalised against (previously this used
        # df.columns[0], which is always "x").
        ylabel = ylabel + " / " + reference
    plt.ylabel(ylabel)
    if logx:
        plt.xscale("log")
    if logy:
        plt.yscale("log")
    plt.legend()
    plt.show()
def count_even_digits_spyr03_for(n):
    """Count characters of ``str(n)`` that are even digits, via a plain loop."""
    # Local named `evens` so it does not shadow itertools.count in here.
    evens = 0
    for ch in str(n):
        if ch in "02468":
            evens += 1
    return evens
def count_even_digits_spyr03_sum(n):
    """Count even-digit characters of ``str(n)`` by summing membership bools."""
    text = str(n)
    return sum(ch in "02468" for ch in text)
def count_even_digits_spyr03_sum2(n):
    """Count even-digit characters of ``str(n)`` by summing 1 per match."""
    text = str(n)
    return sum(1 for ch in text if ch in "02468")
def count_even_digits_spyr03_count_unrolled(n):
    """Count even-digit characters of ``str(n)`` with one str.count per digit."""
    text = str(n)
    return (text.count("0") + text.count("2") + text.count("4")
            + text.count("6") + text.count("8"))
if __name__ == "__main__":
    import sys

    # Python 3.11+ limits int<->str conversion to 4300 digits by default; the
    # largest inputs below have 10**5 digits, so lift the limit where it exists.
    if hasattr(sys, "set_int_max_str_digits"):
        sys.set_int_max_str_digits(0)

    # Random integers whose digit counts are log-spaced from 10 to 100000.
    x = [int("".join(choices(digits, k=n))) for n in np.logspace(1, 5, dtype=int)]
    funcs = [count_even_digits_spyr03_for, count_even_digits_spyr03_sum,
             count_even_digits_spyr03_sum2, count_even_digits_spyr03_count_unrolled]
    # Plot against digit count (~ log10 n): the raw integers are far too large
    # to convert to float for matplotlib. Raw string avoids the invalid "\l"
    # escape warning; the label text is unchanged.
    plot_times(funcs, x, key=lambda n: len(str(n)), xlabel=r"$\log_{10} n$",
               logx=True, ratio=True)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.