Created
March 6, 2019 18:36
-
-
Save pltrdy/095e9142f17699dde20a0c0eaeb2f33f to your computer and use it in GitHub Desktop.
Simulation pour https://twitter.com/ThomasCabaret84/status/1103324141493600256
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
https://twitter.com/ThomasCabaret84/status/1103324141493600256i | |
""" | |
import torch | |
def run(b, n=100, in_a_row=3):
    """
    Coin-toss experiment.

    A perfectly fair coin is tossed `n` times. We look only at the tosses,
    and at all of them, that are immediately preceded by `in_a_row`
    consecutive "heads" results (value 1). Among those tosses, what is the
    proportion of heads? (on average, when the same procedure is repeated)

    `b` experiments are run in parallel (increasing `b` speeds up the
    computation but uses more memory).

    Args:
        b(int): number of experiments run in parallel
        n(int): number of tosses per experiment (100 by default)
        in_a_row(int): number of consecutive heads required before a toss
            is considered; 0 means every toss is considered

    Returns:
        0-dim float tensor: mean, over the experiments that contain at least
        one considered toss, of the proportion of heads among their
        considered tosses. NaN if no experiment contains a considered toss.

    Raises:
        ValueError: if `in_a_row` is negative.
    """
    if in_a_row < 0:
        raise ValueError("in_a_row must be >= 0, got %d" % in_a_row)

    # Tosses encoded as +1 (heads) / -1 (tails), shape (b, n).
    t = torch.rand([b, n]).gt(0.5).long() * 2 - 1
    print("Initial:")
    print(t)
    t_eq_1 = t.eq(1)
    if in_a_row == 0:
        # An empty streak condition holds everywhere. Width n + 1 so that
        # dropping the last column below leaves exactly one flag per toss.
        t_in_a_row = torch.ones(b, n + 1, dtype=torch.bool)
    else:
        # After the loop, column j is True iff tosses j .. j+in_a_row-1 are
        # all heads (width n - in_a_row + 1).
        t_in_a_row = t_eq_1
        for i in range(2, in_a_row + 1):
            t_in_a_row = t_in_a_row[:, :-1] & t_in_a_row[:, 1:]
            print("----\n%d in a row: " % (i))
            print(t_in_a_row)
    # Column j of `index` flags a streak over tosses j .. j+in_a_row-1, so it
    # lines up with toss j+in_a_row, which is column j of `value`.
    index = t_in_a_row[:, :-1].long()
    value = t[:, in_a_row:]
    print("----\nindex: ")
    print(index)
    print("----\nvalue: ")
    print(value)
    # +1 / -1 for considered tosses, 0 for ignored ones.
    selected_toss = index * value
    print("----\nselected_toss: ")
    print(selected_toss)
    n_toss_p1 = selected_toss.eq(1).long()
    print(n_toss_p1)
    # Experiments with no considered toss give 0 / 0 = NaN here.
    freq_p1 = (n_toss_p1.sum(1).float()) / selected_toss.ne(0).float().sum(1)
    print(freq_p1)
    # torch.isnan returns a bool mask; invert it with `~` (arithmetic such as
    # `1 - mask` is rejected for bool tensors).
    filter_freq_p1 = freq_p1[~torch.isnan(freq_p1)]
    print(filter_freq_p1)
    avg_freq_p1 = torch.mean(filter_freq_p1)
    return avg_freq_p1
def run_many(n_run, b, n, in_a_row):
    """
    Repeat `run` `n_run` times and average the per-run results.

    Args:
        n_run(int): number of times `run` is called
        b(int): number of parallel experiments per call
        n(int): number of tosses per experiment
        in_a_row(int): number of consecutive heads required before a toss
            is considered

    Returns:
        0-dim float tensor: mean of the per-run averages, ignoring runs whose
        average is NaN (no experiment in that run had a considered toss).
        NaN if `n_run` is 0 or every run was NaN.
    """
    avg_freqs = torch.tensor([float(run(b, n, in_a_row)) for _ in range(n_run)])
    # A run returns NaN when none of its experiments had a considered toss.
    # Drop those runs, just as run() drops NaN experiments, instead of letting
    # a single empty run turn the overall mean into NaN.
    return torch.mean(avg_freqs[~torch.isnan(avg_freqs)])
if __name__ == "__main__":
    import contextlib
    import os

    # Silence the diagnostic prints emitted by run(). The with-statement
    # closes the devnull handle, and redirect_stdout restores sys.stdout
    # even if the simulation raises.
    with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
        r = run_many(100, 10000, 100, 3)
    print(r)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment