Created September 7, 2018 05:15
Save mvbattan/96448873772147bb7a20aade09dd19da to your computer and use it in GitHub Desktop.
Ejercicio 3 - Generator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# get_truncated_normal is adapted from https://stackoverflow.com/a/44308018
from scipy.stats import norm, truncnorm | |
import matplotlib.pyplot as plt | |
from numpy import array | |
from math import pi, cos | |
# Number of (x, y) points drawn by main().
SAMPLE_SIZE = 200
def get_truncated_normal(mean=0, sd=1, low=0, upp=10):
    """Return a frozen normal distribution N(mean, sd**2) truncated to [low, upp].

    scipy's ``truncnorm`` expects its clip points in standard-score units
    (relative to a standard normal), so both limits are standardized here
    before the distribution is shifted by ``loc`` and scaled by ``scale``.
    """
    a_std = (low - mean) / sd
    b_std = (upp - mean) / sd
    return truncnorm(a_std, b_std, loc=mean, scale=sd)
def m(x):
    """Return the regression function m(x) used to generate the sample.

    m is defined piecewise on the closed interval [-1, 1]:

        (x + 2)**2 / 2             for -1   <= x < -0.5
        x / 2 + 0.875              for -0.5 <= x <= 0
        -5 * (x - 0.2)**2 + 1.075  for 0    <  x <= 0.5
        x + 0.125                  for 0.5  <  x <= 1

    The pieces meet continuously at 0 and 0.5, but not at -0.5: the left
    limit there is 1.125 while m(-0.5) == 0.625.
    NOTE(review): presumably the jump is intentional for the exercise — confirm.

    Raises:
        ValueError: if x lies outside [-1, 1] (or is NaN). Previously such
            inputs fell through every branch and returned None, which only
            surfaced later as a confusing TypeError in the caller.
    """
    if -1 <= x < -0.5:
        return ((x + 2) ** 2) / 2.0
    if -0.5 <= x <= 0:
        return x / 2.0 + 0.875
    if 0 < x <= 0.5:
        return -5 * (x - 0.2) ** 2 + 1.075
    if 0.5 < x <= 1:
        return x + 0.125
    raise ValueError("m(x) is only defined for -1 <= x <= 1, got {!r}".format(x))
def get_conditional_sd(x):
    """Return the standard deviation of the noise at x: 0.2 - 0.1*cos(2*pi*x).

    The value oscillates between 0.1 (at integer x) and 0.3 (at half-integer
    x), so the noise added to m(x) is heteroscedastic.
    """
    amplitude = 0.1
    return 0.2 - amplitude * cos(2 * pi * x)
def get_y_rvs(x):
    """Draw one noise value from N(0, get_conditional_sd(x)**2).

    No random_state is passed, so the draw comes from scipy's default
    random state.
    """
    noise_dist = norm(loc=0, scale=get_conditional_sd(x))
    return noise_dist.rvs()
def main(sample_size=SAMPLE_SIZE):
    """Simulate pairs Y = m(X) + eps, print them and show a scatter plot.

    X is drawn from a standard normal truncated to [-1, 1], and for each x
    the noise eps is drawn from N(0, get_conditional_sd(x)**2).

    Args:
        sample_size: number of (x, y) points to draw; defaults to SAMPLE_SIZE.
    """
    X = get_truncated_normal(mean=0, sd=1, low=-1, upp=1)
    x_samples = X.rvs(sample_size).tolist()
    # All x values are drawn first, then one noise draw per x in order,
    # matching the original RNG consumption order. m() draws nothing.
    y_samples = [m(x) + get_y_rvs(x) for x in x_samples]

    plt.scatter(x_samples, y_samples)
    print(list(zip(x_samples, y_samples)))
    plt.show()
if __name__ == '__main__':
    # Run the simulation only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.