Last active
October 7, 2022 16:25
-
-
Save yknishidate/a85df6901734d397cecdfcccc7308c47 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import random | |
import matplotlib.pyplot as plt | |
class UniformDistribution:
    """Continuous uniform distribution on the closed interval [a, b].

    Used as the proposal (source) distribution for sampling importance
    resampling: it can be sampled directly and its density is known.

    Args:
        a: Lower bound of the support.
        b: Upper bound of the support; must be strictly greater than ``a``.

    Raises:
        ValueError: If ``not a < b`` (the density 1 / (b - a) would be
            undefined or negative).
    """

    def __init__(self, a: float, b: float) -> None:
        # Reject degenerate/reversed intervals up front instead of failing
        # later with ZeroDivisionError or returning a negative density.
        if not a < b:
            raise ValueError(f"require a < b, got a={a}, b={b}")
        self.a = a
        self.b = b

    def sample(self) -> float:
        """Draw one value from [a, b] using the module-level ``random`` generator."""
        return random.uniform(self.a, self.b)

    def pdf(self, x: float) -> float:
        """Return the density at ``x``: 1 / (b - a) inside [a, b], 0 outside."""
        # Closed interval: random.uniform may return either endpoint.
        if self.a <= x <= self.b:
            return 1.0 / (self.b - self.a)
        return 0.0
class SinDistribution:
    """Target density proportional to sin(x) on [0, pi] (unnormalized).

    ``pdf`` returns ``sin(x)`` rather than the normalized ``sin(x) / 2``.
    That is sufficient as an SIR target, because the resampling step picks
    candidates in proportion to their weights, so a constant factor cancels.
    """

    def sample(self) -> float:
        """Always raise: this distribution cannot be sampled directly.

        Raises:
            NotImplementedError: Always. Use ``sample_using_sir`` instead.
        """
        # An explicit raise (rather than `assert False`) still fires under
        # `python -O`, where asserts are stripped and None would be returned.
        raise NotImplementedError('cannot sample from this distribution!')

    def pdf(self, x: float) -> float:
        """Return sin(x) for x in [0, pi] and 0 outside that interval."""
        # Outside [0, pi] sin(x) goes negative, which is not a valid density.
        if 0.0 <= x <= math.pi:
            return math.sin(x)
        return 0.0
def draw_using_weights(candidates: list, weights: list) -> float:
    """Pick one candidate at random with probability proportional to its weight.

    Args:
        candidates: Values to choose from.
        weights: Non-negative weights, one per candidate (need not sum to 1).

    Returns:
        The selected candidate. Candidates with zero weight are never selected.

    Raises:
        ValueError: If the lists differ in length, any weight is negative, or
            no weight is positive (this includes empty input).
    """
    if len(candidates) != len(weights):
        raise ValueError("candidates and weights must have the same length")
    if any(w < 0 for w in weights):
        raise ValueError("weights must be non-negative")
    weight_sum = sum(weights)
    # `not > 0` also rejects NaN totals and the empty / all-zero case, which
    # previously fell through the loop and returned None.
    if not weight_sum > 0:
        raise ValueError("at least one weight must be positive")

    rand = random.uniform(0.0, weight_sum)
    cumulation = 0.0
    last_positive = None
    for candidate, weight in zip(candidates, weights):
        if weight == 0:
            continue  # adds nothing to the running total; never selectable
        cumulation += weight
        last_positive = candidate
        if rand < cumulation:
            return candidate
    # Reached only when rand >= the final running total: random.uniform may
    # return its upper bound, and sum() can round differently from the
    # left-to-right running total (Python 3.12+ uses compensated summation).
    return last_positive
def sample_using_sir(num_candidates: int, source=None, target=None) -> float:
    """Draw one approximate sample from ``target`` by sampling importance resampling.

    ``num_candidates`` proposals are drawn from ``source``. Each proposal x is
    weighted by ``target.pdf(x) / source.pdf(x)``, and one proposal is then
    picked with probability proportional to its weight. Larger
    ``num_candidates`` gives outputs closer to the target distribution.

    Args:
        num_candidates: Number of proposals (M); must be at least 1.
        source: Proposal distribution exposing ``sample()`` and ``pdf(x)``.
            Defaults to ``UniformDistribution(0.0, math.pi)``.
        target: Distribution exposing ``pdf(x)``; it need not be normalized.
            Defaults to ``SinDistribution()``.

    Returns:
        One of the drawn proposals.

    Raises:
        ValueError: If ``num_candidates < 1`` (previously this silently
            produced None).
    """
    if num_candidates < 1:
        raise ValueError(f"num_candidates must be >= 1, got {num_candidates}")
    if source is None:
        source = UniformDistribution(0.0, math.pi)
    if target is None:
        target = SinDistribution()

    candidates = [source.sample() for _ in range(num_candidates)]
    # Importance weight: ratio of target density to proposal density at x.
    weights = [target.pdf(x) / source.pdf(x) for x in candidates]
    return draw_using_weights(candidates, weights)
if __name__ == "__main__":
    # M: proposals drawn from the uniform source for each resampled output.
    num_candidates = 4
    # Total number of SIR outputs collected for the histogram.
    num_samples = 50000

    samples = []
    for _ in range(num_samples):
        samples.append(sample_using_sir(num_candidates))

    plt.hist(samples, bins=20)
    plt.title(f"Sampling Importance Resampling M={num_candidates}")
    plt.show()
    plt.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
SIRを使ってsin分布からランダムサンプリングした