Skip to content

Instantly share code, notes, and snippets.

@HajimeKawahara
Last active December 31, 2021 03:17
Show Gist options
  • Select an option

  • Save HajimeKawahara/153319e36cbc9823fad9174ff10dd564 to your computer and use it in GitHub Desktop.

Select an option

Save HajimeKawahara/153319e36cbc9823fad9174ff10dd564 to your computer and use it in GitHub Desktop.
import pandas as pd
import numpy as np
from exojax.dynamics.rvfunc import rvf
import jax.numpy as jnp
from jax import random
from jax import vmap, jit
import matplotlib.pyplot as plt
# Reference epochs. Only Toff is used below; it equals to1 - to2.
to1 = 2459370  # fit offset
to2 = 2450000  # table Toffset
Toff = to1 - to2

# Observed radial-velocity table with TIME, RV and ERROR columns.
dat = pd.read_csv("rv.txt", delimiter=",")
print(dat)
t = np.array(dat["TIME"] - Toff, dtype=np.float32)
rv = np.array(dat["RV"], dtype=np.float32)
err = np.array(dat["ERROR"], dtype=np.float32)

# Posterior samples are stored as a one-element object array under "arr_0".
# NOTE(review): allow_pickle=True unpickles arbitrary objects, so only load
# files produced by a trusted run of the fitting script.
ps = np.load("npz/savepos.npz", allow_pickle=True)["arr_0"][0]
print(np.median(ps["P"]), np.std(ps["P"]))

secosw = ps["secosw"]
sesinw = ps["sesinw"]
# The fitted model caps secosw**2 + sesinw**2 at 1.0 before handing it to rvf.
# Apply the same cap here so the derived value matches what the sampler used.
eps = np.minimum(secosw**2 + sesinw**2, 1.0)
omegaAps = jnp.arctan2(sesinw, secosw)
def est(val, N=3):
    """Return a ``median_{lo}^{hi}`` summary string for a set of samples.

    ``lo`` and ``hi`` are the offsets of the 5th and 95th percentiles from
    the median. All three numbers are rounded to ``N`` decimals.

    The string is returned rather than printed, so ``print(est(x))`` shows
    the summary once instead of the summary followed by ``None``.

    Args:
        val: array-like of samples.
        N: number of decimals kept in each number.

    Returns:
        The formatted summary, e.g. ``"50.0_{-45.0}^{45.0}"``.
    """
    lower, upper = np.percentile(val, [5, 95])
    med = np.round(np.median(val), N)
    minus = np.round(lower - med, N)
    plus = np.round(upper - med, N)
    return str(med) + "_{" + str(minus) + "}^{" + str(plus) + "}"
# Median and 5/95-percentile summary of the posterior eccentricity.
print(est(eps))

# Reference distribution: the same sum of squares for two independent
# uniform draws on [-1, 1), i.e. what the sampling box alone would give.
N = 4000
ran1 = np.random.rand(N) * 2.0 - 1.0
ran2 = np.random.rand(N) * 2.0 - 1.0
eran = ran1**2 + ran2**2

fig = plt.figure()
ax = fig.add_subplot(111)
# Use identical bin edges for both histograms. With bins=40 alone, each
# histogram spans its own data range, so the bin widths differ and the
# counts cannot be compared.
ebins = np.linspace(0.0, 1.0, 41)
ax.hist(eran, bins=ebins, alpha=0.3, label="random")
ax.hist(eps, bins=ebins, alpha=0.3, label="posterior")
ax.set_xlim(0.0, 1.0)
ax.set_xlabel("e")
ax.legend()
fig.savefig("npz/e.png")
import tqdm
# Three stacked panels: the full time span plus two zoomed windows.
fig = plt.figure(figsize=(20, 20))
ax1 = fig.add_subplot(311)
ax2 = fig.add_subplot(312)
ax3 = fig.add_subplot(313)
ax2.set_xlim(-367.0, -300.0)
ax3.set_xlim(-75.0, 120.0)

panels = (ax1, ax2, ax3)
for panel in panels:
    panel.errorbar(t, rv, yerr=err, ls="none")
    panel.plot(t, rv, "o")

# Dense time grid between the first and last rows of the data table.
tpre = jnp.linspace(t[0], t[-1], 3600)

# Overplot every 10th posterior sample. The stride must be applied to the
# index itself: range(len(ps["P"][::10])) would walk the first tenth of
# consecutive samples instead of thinning the whole chain.
thin = 10
for i in tqdm.tqdm(range(0, len(ps["P"]), thin)):
    e = eps[i]
    T0 = ps["T0"][i]
    P = ps["P"][i]
    omegaA = omegaAps[i]
    Ksini = ps["Ksini"][i]
    Vsys = ps["Vsys"][i]
    model = rvf(tpre, T0, P, e, omegaA, Ksini, Vsys)
    for panel in panels:
        panel.plot(tpre, model, alpha=0.05, color="gray")

fig.savefig("npz/results.png", bbox_inches="tight", pad_inches=0.0)
#corner plot
import arviz
# Raise ArviZ's subplot cap before drawing the pair-plot grid.
arviz.rcParams.update({"plot.max_subplots": 250})

# KDE pair plot of all posterior variables with marginals on the diagonal.
pair_axes = arviz.plot_pair(
    ps,
    kind="kde",
    divergences=False,
    marginals=True,
    textsize=18,
)

# All panels share one figure; fetch it from the first axes and save it
# in both raster and vector form.
corner_fig = pair_axes.ravel()[0].figure
for out_path in ("npz/cornerall.png", "npz/cornerall.pdf"):
    corner_fig.savefig(out_path, bbox_inches="tight", pad_inches=0.0)
plt.show()
import pandas as pd
import numpy as np
from exojax.dynamics.rvfunc import rvf
import jax.numpy as jnp
from jax import random
from jax import vmap, jit
import matplotlib.pyplot as plt
# Reference epochs; only Toff is applied to the TIME column below.
to1 = 2459370  # fit offset
to2 = 2450000  # table Toffset
Toff = 9370

# Load the radial-velocity table and pull the three columns out as float32.
dat = pd.read_csv("rv.txt", delimiter=",")
print(dat)
t, rv, err = (
    np.array(column, dtype=np.float32)
    for column in (dat["TIME"] - Toff, dat["RV"], dat["ERROR"])
)

# Hand-picked parameter set used for a quick visual check before the MCMC run.
P, Ksini, Tp = 10.77, 3.75, 0.43
e, omegaA, Vsys = 0.32, -0.53, -0.32
model = rvf(t, Tp, P, e, omegaA, Ksini, Vsys)

# Fold the time axis on the trial period and overlay model and data.
phase = ((t - Tp) % P) / P
for series, marker in ((model, "+"), (rv, ".")):
    plt.plot(phase, series, marker)
plt.savefig("best.png")
#import sys
#sys.exit()
#HMC-NUTS FITTING PART
#######################################################
import numpyro.distributions as dist
import numpyro
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.diagnostics import hpdi
def model_c(t1,y1,e1):
    """NumPyro model: radial-velocity curve from ``rvf`` plus a jitter term.

    Sample sites, in order: ``P``, ``Ksini``, ``T0``, ``sesinw``, ``secosw``,
    ``sigmajit``, ``Vsys``, and the observed site ``y1``.

    Args:
        t1: times at which ``rvf`` is evaluated.
        y1: observed values the Normal likelihood is conditioned on.
        e1: per-point errors, added in quadrature to the sampled jitter.
    """
    period = numpyro.sample("P", dist.Uniform(8.0, 12.0))
    # TODO: replace this prior with a modified Jeffreys prior.
    k_sini = numpyro.sample("Ksini", dist.Exponential(0.1))
    t0 = numpyro.sample("T0", dist.Uniform(-6.0, 6.0))
    se_sin = numpyro.sample("sesinw", dist.Uniform(-1.0, 1.0))
    se_cos = numpyro.sample("secosw", dist.Uniform(-1.0, 1.0))

    # The two sampled components give the eccentricity as a sum of squares.
    # The unit-square prior can exceed 1, so the value is capped at 1.0.
    ecc_raw = se_sin**2 + se_cos**2
    ecc = jnp.where(ecc_raw > 1.0, 1.0, ecc_raw)
    omega = jnp.arctan2(se_sin, se_cos)

    jitter = numpyro.sample("sigmajit", dist.Exponential(1.0))
    v_sys = numpyro.sample("Vsys", dist.Uniform(-10, 10.0))

    predicted = rvf(t1, t0, period, ecc, omega, k_sini, v_sys)
    total_sigma = jnp.sqrt(e1**2 + jitter**2)
    numpyro.sample("y1", dist.Normal(predicted, total_sigma), obs=y1)
#Running a HMC-NUTS
import os

# np.savez does not create missing directories. Make sure "npz/" exists
# before the long sampling run so the finished chain is not lost to a
# FileNotFoundError at the very end.
os.makedirs("npz", exist_ok=True)

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
# For a quick smoke test use e.g. num_warmup, num_samples = 100, 300.
num_warmup, num_samples = 2000, 4000
kernel = NUTS(model_c)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key_, t1=t, y1=rv, e1=err)
mcmc.print_summary()
print("end HMC")

#Post-processing
# The sample dict is wrapped in a list, so it is stored as a one-element
# object array under the default key "arr_0". Reading it back requires
# np.load(..., allow_pickle=True).
posterior_sample = mcmc.get_samples()
np.savez("npz/savepos.npz", [posterior_sample])
@HajimeKawahara
Copy link
Copy Markdown
Author

Requires exojax.

@HajimeKawahara
Copy link
Copy Markdown
Author

pip install exojax

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment