Created
March 12, 2024 22:44
-
-
Save Nikolaj-K/2d40efe9f4f2dbc77e8a8045ec63a481 to your computer and use it in GitHub Desktop.
Manually sampling from any given 1D histogram
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Script discussed in the video:
https://youtu.be/ndAmT8CYGDM
Links:
* https://en.wikipedia.org/wiki/Energy-based_model
* https://en.wikipedia.org/wiki/Brownian_motion
* https://www.extropic.ai/future
""" | |
import random | |
import time | |
def sample_digit_from_time(t=None):
    """Derive one binary digit from the millisecond digit of a timestamp.

    The third decimal place of the timestamp (the millisecond digit) is
    read from its string form and reduced to its parity.

    Args:
        t: Timestamp in seconds, of the kind returned by ``time.time()``.
            When omitted, the current time is used.

    Returns:
        0 if the millisecond digit is even, 1 if it is odd.
    """
    if t is None:
        t = time.time()
    # partition() (unlike a two-way unpack of split()) also tolerates a
    # string form that has no decimal point at all.
    _secs, _dot, subsecs = str(t).partition(".")
    # str() of a float drops trailing zeros (e.g. "1710283456.5"), so the
    # fractional part may be shorter than three characters; pad with the
    # zeros that were dropped so the millisecond digit always exists.
    ms_digit = int(subsecs.ljust(3, "0")[2])
    bin_digit = ms_digit % 2
    print(f"t={t}, ms_digit={ms_digit}, bin_digit={bin_digit}")
    return bin_digit
def sample_digit_list(num_digits, use_random_lib):
    """Collect ``num_digits`` binary digits and pause until Enter is pressed.

    Args:
        num_digits: How many digits to draw.
        use_random_lib: If truthy, each digit comes from
            ``random.choice([0, 1])``; otherwise each digit comes from
            ``sample_digit_from_time()``.

    Returns:
        The list of drawn digits (each 0 or 1), in draw order.
    """
    drawn = [
        random.choice([0, 1]) if use_random_lib else sample_digit_from_time()
        for _ in range(num_digits)
    ]
    print(f"(digit sum = {sum(drawn)} of {num_digits})")
    # Block until the user presses Enter before handing the digits back.
    input()
    return drawn
def unit_interval_real(bin_digits):
    """Interpret a digit sequence as the binary fraction 0.d1 d2 d3 ...

    The digit at position ``n`` (0-based) contributes
    ``digit / 2 ** (n + 1)`` to the result. Each step is printed.

    Args:
        bin_digits: Iterable of binary digits, most significant first.

    Returns:
        The accumulated value; ``0`` for an empty sequence.
    """
    acc = 0
    denominator = 1
    for level, bit in enumerate(bin_digits):
        denominator *= 2  # 2, 4, 8, ... i.e. 2 ** (level + 1)
        term = bit / denominator
        acc = acc + term
        print(f"Level {level}, power={denominator}, digit={bit}: add {term}, get x={acc}")
    return acc
def draw_sample(hist, x):
    """Map a number in [0, 1) to a histogram bin by inverting the CDF.

    Bin ``n`` owns the half-open interval ``[cdf[n-1], cdf[n])`` of the
    normalized cumulative weights (with ``cdf[-1]`` taken as 0), so every
    ``x`` in ``[0, 1)`` falls into exactly one bin of non-zero weight.

    Args:
        hist: Sequence of bin weights.
        x: Number to locate, expected in ``[0, 1)``.

    Returns:
        Index of the bin whose interval contains ``x``, or ``None`` when no
        bin contains it (``x`` outside ``[0, 1)`` or ``hist`` empty).

    Raises:
        ZeroDivisionError: If ``hist`` is non-empty but its weights sum to 0.
    """
    total = sum(hist)
    running = 0
    lb = 0  # lower bound of the current bin's interval
    for n, weight in enumerate(hist):
        running += weight
        ub = running / total  # normalized cdf value = upper bound of bin n
        # Half-open interval: the lower bound is inclusive so that x == 0 and
        # x exactly on a bin boundary are still assigned to a bin.
        if lb <= x < ub:
            return n
        lb = ub
    return None
if __name__ == '__main__':
    # Weights of the distribution to sample from; indices 0, 3 and 7 carry
    # the largest weights and will therefore come up most often.
    HIST = [17, 3, 4, 20, 1, 13, 2, 30]

    # A single run built from 15 time-derived digits.
    bin_digits_list = [sample_digit_list(15, False)]
    # Alternative: 20 runs of 10 digits each from the random library:
    # bin_digits_list = [sample_digit_list(10, True) for _ in range(20)]

    drawn_samples = []
    for bin_digits in bin_digits_list:
        x = unit_interval_real(bin_digits)
        # One "index->weight (chance)" row per histogram bin.
        table_rows = [
            f"{idx}->{h} ({round(100 * h / sum(HIST), 1)} %)"
            for idx, h in enumerate(HIST)
        ]
        header = f"\nRun bin_digits={bin_digits}\nx={x}\n\n"
        legend = "Sample from histogram:\nindex->weight (chance)\n"
        print(header + legend + "\n".join(table_rows) + "\n")
        drawn_samples.append(draw_sample(HIST, x))
    print(f"\nDrew {drawn_samples}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment