@WillianFuks
Created December 17, 2020 02:46
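```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Priors: LogNormal on the level's random-walk scale, Gaussian on the initial level.
observed_stddev, observed_initial = (tf.convert_to_tensor(value=1., dtype=tf.float32),
                                     tf.convert_to_tensor(value=0., dtype=tf.float32))
level_scale_prior = tfd.LogNormal(loc=tf.math.log(0.05 * observed_stddev), scale=1., name='level_scale_prior')
initial_state_prior = tfd.MultivariateNormalDiag(loc=observed_initial[..., tf.newaxis],
                                                 scale_diag=(tf.abs(observed_initial) + observed_stddev)[..., tf.newaxis],
                                                 name='initial_level_prior')
```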
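With `loc=log(0.05 * observed_stddev)` and `scale=1`, the LogNormal prior on the level scale has median `exp(loc) = 0.05 * observed_stddev`, so with `observed_stddev = 1` a typical per-step drift is about 0.05, while the heavy right tail pulls the mean up to `exp(loc + scale**2 / 2)`. The sketch below checks that arithmetic with NumPy's `Generator.lognormal` in place of TFP (a swapped-in sampler, used only to keep the check light); it is an illustration of the prior's shape, not part of the original snippet.

```python
import numpy as np

observed_stddev = 1.0  # same value as in the gist
loc, scale = np.log(0.05 * observed_stddev), 1.0

# If log(X) ~ Normal(loc, scale) then X ~ LogNormal(loc, scale).
rng = np.random.default_rng(0)
samples = rng.lognormal(mean=loc, sigma=scale, size=200_000)

# Closed forms: median = exp(loc), mean = exp(loc + scale**2 / 2).
median_theory = np.exp(loc)
mean_theory = np.exp(loc + scale**2 / 2)
print(np.median(samples), median_theory)  # both near 0.05
print(samples.mean(), mean_theory)        # both near 0.082
```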
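```python
# Simulate 100 steps of a local level (random walk) model, using one draw of the level scale from its prior.
ll_ssm = tfp.sts.LocalLevelStateSpaceModel(100, initial_state_prior=initial_state_prior, level_scale=level_scale_prior.sample())
ll_ssm_sample = np.squeeze(ll_ssm.sample().numpy())  # shape (100, 1) -> (100,)
```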
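A local level model treats the level as a Gaussian random walk, `level[t] = level[t-1] + Normal(0, level_scale)`, observed with noise of scale `observation_noise_scale`; that argument is left at its default of 0 above, so the sample is the level path itself. The plain NumPy sketch below mirrors that generative process, assuming the initial level is drawn from Normal(0, 1) (matching `initial_state_prior` with `observed_initial = 0` and `observed_stddev = 1`) and using a fixed, hypothetical `level_scale = 0.05` instead of a prior draw. `simulate_local_level` is a hypothetical helper for illustration, not TFP's implementation.

```python
import numpy as np

def simulate_local_level(num_timesteps, level_scale, initial_loc=0.0, initial_scale=1.0, seed=None):
    """Hypothetical helper: level[0] ~ N(initial_loc, initial_scale),
    level[t] = level[t-1] + N(0, level_scale) for t >= 1, no observation noise."""
    rng = np.random.default_rng(seed)
    initial_level = rng.normal(initial_loc, initial_scale)
    steps = rng.normal(0.0, level_scale, size=num_timesteps - 1)
    # The cumulative sum turns i.i.d. increments into a random walk.
    return initial_level + np.concatenate([[0.0], np.cumsum(steps)])

level_path = simulate_local_level(100, level_scale=0.05, seed=1)
print(level_path.shape)  # (100,)
```

With `level_scale=0.0` the path collapses to a constant equal to the initial level, which is a quick way to see that the scale controls how far the level drifts.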
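```python
# Two uniform random covariates; y is a linear combination of them plus the simulated level.
x0 = 100 * np.random.rand(100)
x1 = 90 * np.random.rand(100)
y = 1.2 * x0 + 0.9 * x1 + ll_ssm_sample
data = pd.DataFrame({'x0': x0, 'x1': x1, 'y': y}, columns=['y', 'x0', 'x1'])  # response in the first column
```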
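Because `y` is built as `1.2 * x0 + 0.9 * x1` plus a slowly drifting level, an ordinary least-squares fit of `y` on `x0`, `x1` and an intercept should recover coefficients close to 1.2 and 0.9 whenever the level stays small relative to the covariates. The sketch below shows this with `numpy.linalg.lstsq`; so that it runs without TensorFlow, it substitutes a hypothetical small random walk for `ll_ssm_sample`.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
x0 = 100 * rng.random(100)
x1 = 90 * rng.random(100)
level = np.cumsum(rng.normal(0.0, 0.05, size=100))  # hypothetical stand-in for ll_ssm_sample
y = 1.2 * x0 + 0.9 * x1 + level
data = pd.DataFrame({'x0': x0, 'x1': x1, 'y': y}, columns=['y', 'x0', 'x1'])

# Design matrix with an intercept column to absorb the level's average offset.
X = np.column_stack([np.ones(len(data)), data['x0'], data['x1']])
coef, *_ = np.linalg.lstsq(X, data['y'].to_numpy(), rcond=None)
intercept, b0, b1 = coef
print(data.columns.tolist(), round(b0, 3), round(b1, 3))
```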
```python
# Plot the three series; the dashed vertical line marks index 69.
data.plot()
plt.axvline(69, linestyle='--', color='k')
plt.legend();
```
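The dashed line at index 69 looks like the boundary between a pre-period and a post-period, as in causal impact style analyses; that reading is an assumption, since the snippet does not say what the line marks. Assuming inclusive bounds of rows 0 to 69 for the pre-period and 70 to 99 for the post-period (hypothetical values chosen to match the line), the frame can be split positionally as below. The random `data` frame is a stand-in with the same shape and column order as the gist's.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the gist's `data` frame (100 rows, columns y, x0, x1).
rng = np.random.default_rng(0)
data = pd.DataFrame(rng.random((100, 3)), columns=['y', 'x0', 'x1'])

pre_period = [0, 69]    # assumed inclusive bounds, up to the dashed line
post_period = [70, 99]  # assumed inclusive bounds, everything after it

# .iloc slicing excludes the stop index, so add 1 to each inclusive upper bound.
pre_data = data.iloc[pre_period[0]:pre_period[1] + 1]
post_data = data.iloc[post_period[0]:post_period[1] + 1]
print(len(pre_data), len(post_data))  # 70 30
```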