Skip to content

Instantly share code, notes, and snippets.

@cradiator
Created February 11, 2025 01:32
Show Gist options
  • Save cradiator/6159b75e9495a73cffa90b9a81b4eed4 to your computer and use it in GitHub Desktop.
Save cradiator/6159b75e9495a73cffa90b9a81b4eed4 to your computer and use it in GitHub Desktop.
Numba test
import time
import numba
from concurrent.futures import ThreadPoolExecutor
# Exclusive upper bound of the series-term index range used by the benchmarks.
TOTAL_ITER = 10**8
def profiling(func, total_iter=None):
    """Time a single-threaded call of ``func(0, total_iter)`` and print it.

    Args:
        func: Callable taking ``(start, end)`` and returning a number; its
            ``__name__`` appears in the printed report.
        total_iter: Exclusive upper bound of the index range. Defaults to the
            module-level ``TOTAL_ITER``, looked up at call time.

    Returns:
        ``(elapsed_seconds, result)`` so callers can use the measurement
        programmatically in addition to the printed line.
    """
    if total_iter is None:
        total_iter = TOTAL_ITER
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # system wall clock, which can be adjusted mid-measurement.
    start_time = time.perf_counter()
    result = func(0, total_iter)
    elapsed = time.perf_counter() - start_time
    print(f"single thread {func.__name__} time: {elapsed:.4f} result: {result}")
    return elapsed, result
def profiling_multithread(func, thread_cnt: int, total_iter=None):
    """Time ``func`` evaluated over ``[0, total_iter)`` split across threads.

    The index range is partitioned into ``thread_cnt`` contiguous,
    non-overlapping chunks ``[lo, hi)``. Each chunk is submitted to a pool of
    ``thread_cnt`` worker threads and the per-chunk results are summed in
    chunk order. This decomposition is only meaningful when
    ``func(a, c) == func(a, b) + func(b, c)``.

    Args:
        func: Callable taking ``(start, end)`` and returning a number; its
            ``__name__`` appears in the printed report.
        thread_cnt: Number of chunks and of worker threads; must be >= 1.
        total_iter: Exclusive upper bound of the index range. Defaults to the
            module-level ``TOTAL_ITER``, looked up at call time.

    Returns:
        ``(elapsed_seconds, result)``.

    Raises:
        ValueError: If ``thread_cnt`` is less than 1.
    """
    if thread_cnt < 1:
        raise ValueError(f"thread_cnt must be >= 1, got {thread_cnt}")
    if total_iter is None:
        total_iter = TOTAL_ITER
    # Proportional integer boundaries cover [0, total_iter) exactly once, even
    # when total_iter is not a multiple of thread_cnt; if thread_cnt exceeds
    # total_iter some chunks are simply empty (lo == hi).
    bounds = [total_iter * k // thread_cnt for k in range(thread_cnt + 1)]
    start_time = time.perf_counter()
    # Size the pool explicitly so thread_cnt controls the worker count, not
    # just the number of chunks.
    with ThreadPoolExecutor(max_workers=thread_cnt) as executor:
        futures = [
            executor.submit(func, lo, hi) for lo, hi in zip(bounds, bounds[1:])
        ]
        result = 0.0
        for future in futures:
            result += future.result()
    elapsed = time.perf_counter() - start_time
    print(f"multi thread {func.__name__} time: {elapsed:.4f} result: {result}")
    return elapsed, result
# pi = 4 * (1 - 1/3 + 1/5 - 1/7 + 1/9 - 1/11 + ...)   (Leibniz series)
def calculate_pi(start: int, end: int) -> float:
    """Return 4 times the sum of the Leibniz-series terms with index in [start, end).

    Term ``i`` is ``(-1)**i / (2*i + 1)``, so ``calculate_pi(0, n)`` is the
    n-term approximation of pi. Results for adjacent ranges add up (up to
    floating-point rounding):
    ``calculate_pi(a, c) == calculate_pi(a, b) + calculate_pi(b, c)``.
    An empty range returns ``0.0``.

    Written as a plain Python loop so it can be timed against the
    Numba-compiled variants of the same kernel.
    """
    result = 0.0
    # The sign of term i depends on the parity of i itself; a chunk that
    # begins at an odd index must start with a subtraction.
    positive = start % 2 == 0
    for i in range(start, end):
        tmp = 1.0 / float(i * 2 + 1)
        if positive:
            result += tmp
        else:
            result -= tmp
        positive = not positive
    result = result * 4.0
    return result
@numba.jit
def calculate_pi_jit(start: int, end: int) -> float:
    """Numba-compiled ``calculate_pi``: 4 * sum of Leibniz terms in [start, end).

    Term ``i`` is ``(-1)**i / (2*i + 1)``; results for adjacent ranges add up,
    so the range can be split into chunks and the partial results summed.
    An empty range returns ``0.0``.
    """
    result = 0.0
    # The sign of term i depends on the parity of i itself; a chunk that
    # begins at an odd index must start with a subtraction.
    positive = start % 2 == 0
    for i in range(start, end):
        tmp = 1.0 / float(i * 2 + 1)
        if positive:
            result += tmp
        else:
            result -= tmp
        positive = not positive
    result = result * 4.0
    return result
@numba.jit(nogil=True)
def calculate_pi_jit_nogil(start: int, end: int) -> float:
    """Numba-compiled ``calculate_pi`` that releases the GIL while it runs.

    ``nogil=True`` lets several threads execute this compiled function at the
    same time. Computes 4 * sum of Leibniz terms ``(-1)**i / (2*i + 1)`` for
    ``i`` in [start, end); results for adjacent ranges add up, and an empty
    range returns ``0.0``.
    """
    result = 0.0
    # The sign of term i depends on the parity of i itself; a chunk that
    # begins at an odd index must start with a subtraction.
    positive = start % 2 == 0
    for i in range(start, end):
        tmp = 1.0 / float(i * 2 + 1)
        if positive:
            result += tmp
        else:
            result -= tmp
        positive = not positive
    result = result * 4.0
    return result
def main():
    """Run the pi benchmarks.

    Times the pure-Python and Numba kernels single-threaded, then times each
    of the three kernels split across 4 threads.
    """
    # Call each jitted kernel once up front: Numba compiles on the first call,
    # and that compile time should not land inside the timed runs.
    for jitted in (calculate_pi_jit, calculate_pi_jit_nogil):
        jitted(0, 1)

    for kernel in (calculate_pi, calculate_pi_jit):
        profiling(kernel)

    for kernel in (calculate_pi, calculate_pi_jit, calculate_pi_jit_nogil):
        profiling_multithread(kernel, 4)


if __name__ == "__main__":
    # Only benchmark when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment