Last active: August 28, 2019 19:31
-
-
Save allanj/451a0e3c8e76556feb18f8a50a9a3af9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Callable, List, Optional, TypeVar

import numpy as np
T = TypeVar('T')


def bootstrap_paired_ttest(results_a: List[T],
                           results_b: List[T],
                           evaluate_func: Callable[[List[T]], float],
                           sample_times: int = 10000,
                           print_temporary_p_value: bool = False,
                           time_window_to_print_temporary_p_value: int = 20,
                           seed: Optional[int] = None) -> float:
    """
    Estimate the p-value for "system A is better than system B" by paired bootstrap.

    Despite the function name, this is not a Student's t-test. It is a paired
    bootstrap resampling test:

    - On each of ``sample_times`` iterations, ``len(results_a)`` indices are drawn
      uniformly with replacement.
    - The same indices select instances from both systems; this is what makes the
      test paired.
    - ``evaluate_func`` is applied to each resampled list.

    The returned value is the fraction of iterations in which A's metric is
    strictly lower than B's. Iterations where the two metrics are equal are not
    counted against A.

    :param results_a: per-instance results from system A
    :param results_b: per-instance results from system B, aligned index-by-index
        with ``results_a`` (must have the same length)
    :param evaluate_func: maps a list of results to a scalar metric; the test
        treats a larger metric as better
    :param sample_times: number of bootstrap iterations (must be positive)
    :param print_temporary_p_value: if True, print the running p-value estimate
        periodically
    :param time_window_to_print_temporary_p_value: print the running estimate
        every this many iterations (must be positive when printing is enabled)
    :param seed: optional seed for a private ``np.random.RandomState`` so results
        are reproducible; when None (default) the global NumPy random state is
        used, exactly as before
    :return: the estimated p-value, a fraction in [0, 1]
    :raises ValueError: if the inputs differ in length or are empty, or if a
        count parameter is not positive
    """
    total_num = len(results_a)
    # A paired test compares instance i of A with instance i of B; unequal
    # lengths would either raise IndexError mid-run or silently drop B's tail.
    if total_num != len(results_b):
        raise ValueError(
            f"results_a and results_b must have the same length, "
            f"got {total_num} and {len(results_b)}")
    if total_num == 0:
        raise ValueError("results_a and results_b must not be empty")
    if sample_times <= 0:
        raise ValueError(f"sample_times must be positive, got {sample_times}")
    if print_temporary_p_value and time_window_to_print_temporary_p_value <= 0:
        raise ValueError(
            "time_window_to_print_temporary_p_value must be positive, "
            f"got {time_window_to_print_temporary_p_value}")

    # The np.random module and RandomState expose the same randint(low, high, size).
    rng = np.random if seed is None else np.random.RandomState(seed)

    p = 0  # number of resamples in which A scored strictly below B
    for i in range(sample_times):
        # One index set shared by both systems -> paired resampling with replacement.
        sample_indices = rng.randint(0, total_num, total_num)
        metric_a = evaluate_func([results_a[index] for index in sample_indices])
        metric_b = evaluate_func([results_b[index] for index in sample_indices])
        if metric_a < metric_b:
            p += 1
        if print_temporary_p_value and (i + 1) % time_window_to_print_temporary_p_value == 0:
            temporary_p_value = p / (i + 1)
            print(f"p value at {i+1} iteration is {temporary_p_value}")
    return p / sample_times
""" | |
Example usage: | |
""" | |
results_a = [50, 100, 22,33, 21] | |
results_b = [10, 22, 1,45,20] | |
def evaluate(results:List[int]) -> float: | |
return sum(results)/len(results) | |
bootstrap_paired_ttest(results_a=results_a, results_b=results_b, evaluate_func=evaluate) |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.