Install notes from BDL2022f env install on 2022-11-28
- For CUDA TOOLKIT 11.3, which can be used on older devices but may not be optimal
- set up basic conda env without any torch or jax packages, via
conda env create -f bdl_2022f.yml
import numpy as np | |
import jax | |
import jax.numpy as jnp | |
import jax.nn | |
def calc_trans_mat(x_TD, r_KD, p_KK): | |
''' Compute transition matrix for each timestep | |
Args |
import jax | |
import jax.numpy as jnp | |
if __name__ == '__main__': | |
print("jax.devices()") | |
print(jax.devices()) | |
a = jnp.asarray([[1.0, 2.0, 3.0], [4., 5., 6.]]) | |
b = jnp.asarray([[1.0, 2.0], [3.0, 4.0], [5., 6.]]) |
Install notes from BDL2022f env install on 2022-11-28
conda env create -f bdl_2022f.yml
import numpy as np | |
import scipy.stats | |
import matplotlib.pyplot as plt | |
from statsmodels.distributions.empirical_distribution import ECDF | |
def create_transform_func_to_match_source(target_x_ND, src_x_MD, n_quantiles=1000): | |
''' |
''' VI for Poisson Normal | |
Model | |
----- | |
Latent variable z is drawn from a Normal prior: z ~ Normal( 40, 10) | |
Data y is drawn iid from a Poisson likelihood: y_n ~ Poisson(z) | |
Approx Posterior | |
---------------- | |
Posterior on z is assumed to be Normal with unknown mean and stddev |
window_size | sample_id | accuracy | |
---|---|---|---|
5.0 | 0 | 0.6236094882645041 | |
5.0 | 1 | 0.593111865845944 | |
5.0 | 2 | 0.6060493252962028 | |
5.0 | 3 | 0.6342719738873018 | |
5.0 | 4 | 0.6259239448289695 | |
5.0 | 5 | 0.5623114809268821 | |
5.0 | 6 | 0.6054087015122116 | |
5.0 | 7 | 0.5807796285836049 | |
5.0 | 8 | 0.5818560349582886 |
# Log-spaced values from 10**0 to 10**5 (np.logspace defaults: num=50, base=10),
# generated once at single precision and once at double precision.
vals_float32 = np.logspace(0, 5, dtype=np.float32)
vals_float64 = np.logspace(0, 5, dtype=np.float64)
## Pretty-print output of array so each float takes same num chars
def pprint_arr(arr, n_per_line=6):
    ''' Print values of an array in scientific notation, several per line

    Each value is formatted with 2 digits after the decimal point and a
    3-digit exponent (e.g. 1.00e+000), then right-aligned in a field of
    10 characters. Values on the same line are separated by one space.
    A blank line is printed after the last row.

    Args
    ----
    arr : array
        Values to print. Assumed to be a 1D numpy array (uses arr.size
        and slices along the first axis) -- TODO confirm against callers.
    n_per_line : int
        Maximum number of values printed on each line.

    Returns
    -------
    None. Output goes to stdout.
    '''
    for s in range(0, arr.size, n_per_line):
        chunk = arr[s:s+n_per_line]
        print(" ".join([
            "%10s" % np.format_float_scientific(
                x, precision=2, unique=False, exp_digits=3)
            for x in chunk]))
    # Blank line after the array has been printed
    print()