Skip to content

Instantly share code, notes, and snippets.

@mvbattan
Created September 7, 2018 05:15
Show Gist options
  • Save mvbattan/96448873772147bb7a20aade09dd19da to your computer and use it in GitHub Desktop.
Save mvbattan/96448873772147bb7a20aade09dd19da to your computer and use it in GitHub Desktop.
Ejercicio 3 - Generator
# get_truncated_normal() below is taken from https://stackoverflow.com/a/44308018
from scipy.stats import norm, truncnorm
import matplotlib.pyplot as plt
from numpy import array
from math import pi, cos
# Number of (x, y) points that main() draws, plots and prints.
SAMPLE_SIZE: int = 200
def get_truncated_normal(mean=0, sd=1, low=0, upp=10):
    """Return a frozen normal(mean, sd) distribution truncated to [low, upp].

    scipy's ``truncnorm`` expects its clip points in standard-normal units,
    so ``low``/``upp`` are converted to z-scores before freezing the
    distribution with ``loc=mean`` and ``scale=sd``.

    Args:
        mean: centre of the underlying (untruncated) normal.
        sd: standard deviation of the underlying normal; must be non-zero.
        low: lower truncation bound, in data units.
        upp: upper truncation bound, in data units.

    Returns:
        A frozen ``scipy.stats.truncnorm`` distribution.
    """
    z_low = (low - mean) / sd
    z_upp = (upp - mean) / sd
    return truncnorm(z_low, z_upp, loc=mean, scale=sd)
def m(x):
    """Evaluate the piecewise regression function m(x) on [-1, 1].

        m(x) = (x + 2)**2 / 2            for -1   <= x <  -0.5
               x / 2 + 0.875             for -0.5 <= x <=  0
               -5 * (x - 0.2)**2 + 1.075 for  0   <  x <=  0.5
               x + 0.125                 for  0.5 <  x <=  1

    m is continuous at x = 0 and x = 0.5. It jumps at x = -0.5, where the
    limit from the left is 1.125 while m(-0.5) = 0.625.
    NOTE(review): presumably the jump is intentional for the exercise.

    Args:
        x: a real number in [-1, 1].

    Returns:
        The value m(x) as a float.

    Raises:
        ValueError: if x lies outside [-1, 1] or is NaN.
    """
    if -1 <= x < -0.5:
        return ((x + 2) ** 2) / 2.0
    if -0.5 <= x <= 0:
        return x / 2.0 + 0.875
    if 0 < x <= 0.5:
        return -5 * (x - 0.2) ** 2 + 1.075
    if 0.5 < x <= 1:
        return x + 0.125
    # Every comparison above is False for out-of-domain values and for NaN.
    # Falling off the end used to return None silently, which only failed
    # later as a TypeError when added to the noise term.
    raise ValueError(f"m(x) is only defined for -1 <= x <= 1, got {x!r}")
def get_conditional_sd(x):
    """Return the noise standard deviation sigma(x) = 0.2 - 0.1 * cos(2*pi*x).

    For finite x, sigma oscillates between 0.1 (at integer x) and 0.3 (at
    half-integer x), so it is always strictly positive.
    """
    angle = 2 * pi * x
    return 0.2 - 0.1 * cos(angle)
def get_y_rvs(x):
    """Draw a single noise term eps ~ Normal(0, get_conditional_sd(x)).

    Despite the name, this returns only the noise component. main() adds
    m(x) to it to obtain the observed y value.
    """
    sigma = get_conditional_sd(x)
    noise_dist = norm(0, sigma)
    return noise_dist.rvs()
def main():
    """Simulate SAMPLE_SIZE points from Y = m(X) + eps, then plot and print them.

    X is a standard normal truncated to [-1, 1]. Conditional on X = x, the
    noise eps is drawn by get_y_rvs(x). The resulting (x, y) pairs are shown
    as a scatter plot and printed as a list of tuples.
    """
    X = get_truncated_normal(mean=0, sd=1, low=-1, upp=1)
    x_samples = X.rvs(SAMPLE_SIZE).tolist()
    # Draw all x values first, then one noise term per x, in order. This keeps
    # random-number consumption identical to the original implementation.
    noise = [get_y_rvs(x) for x in x_samples]
    # Pair each x with its own noise term via zip rather than indexing
    # into a parallel list.
    y_samples = [eps + m(x) for x, eps in zip(x_samples, noise)]
    plt.scatter(array(x_samples), array(y_samples))
    print(list(zip(x_samples, y_samples)))
    plt.show()
# Run the sampling/plotting demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment