Skip to content

Instantly share code, notes, and snippets.

@pltrdy
Created March 6, 2019 18:36
Show Gist options
  • Save pltrdy/095e9142f17699dde20a0c0eaeb2f33f to your computer and use it in GitHub Desktop.
Save pltrdy/095e9142f17699dde20a0c0eaeb2f33f to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
"""
https://twitter.com/ThomasCabaret84/status/1103324141493600256i
"""
import torch
def run(b, n=100, in_a_row=3):
"""
Expérience du lancer de pièce.
Une pièce parfaitement équilibrée est lancée `n` fois.
On ne regarde que les lancers, et tous les lancers, qui sont
précédés de `in_a_row` résultats "pile" (valeur 1) consécutifs.
Parmi ces lancers, quelle est la proportion de pile?
(en moyenne si on répète la même procédure)
On execute `b` expériences en parallèle
(augmenter b accelère le calcul mais consome plus de mémoire)
Args:
b(int): nombre d'expérience en parallèle
n(int): nombre de lancers (100 par défaut)
in_a_row(int): nombre de pile succéssifs avant de considérer
un lancer
"""
t = torch.rand([b, n]).gt(0.5).long() * 2 - 1
print("Initial:")
print(t)
t_eq_1 = t.eq(1)
t_in_a_row = t_eq_1
for i in range(2, in_a_row + 1):
t_in_a_row = t_in_a_row[:, :-1] * t_in_a_row[:, 1:]
print("----\n%d in a row: " % (i))
print(t_in_a_row)
index = t_in_a_row[:, :-1].long()
value = t[:, in_a_row:]
print("----\nindex: ")
print(index)
print("----\nvalue: ")
print(value)
selected_toss = index * value
print("----\nselected_toss: ")
print(selected_toss)
n_toss_p1 = selected_toss.eq(1).long()
n_toss_m1 = selected_toss.eq(-1).long()
print(n_toss_p1)
freq_p1 = (n_toss_p1.sum(1).float()) / selected_toss.ne(0).float().sum(1)
print(freq_p1)
filter_freq_p1 = freq_p1[1 - torch.isnan(freq_p1)]
print(filter_freq_p1)
avg_freq_p1 = torch.mean(filter_freq_p1)
return avg_freq_p1
def run_many(n_run, b, n, in_a_row):
avg_freqs = []
for i in range(n_run):
avg_freqs += [run(b, n, in_a_row)]
return torch.mean(torch.tensor(avg_freqs))
if __name__ == "__main__":
import sys
import os
_stdout = sys.stdout
sys.stdout = open(os.devnull, 'w')
r = run_many(100, 10000, 100, 3)
sys.stdout = _stdout
print(r)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment