Skip to content

Instantly share code, notes, and snippets.

@streeto
Created June 30, 2014 20:58
Show Gist options
  • Save streeto/7be5ff5eabb08994d1b8 to your computer and use it in GitHub Desktop.
Save streeto/7be5ff5eabb08994d1b8 to your computer and use it in GitHub Desktop.
Correct behaviour in velocity_fix using 1-d fluxes
'''
Created on Jun 27, 2013
@author: andre
'''
import numpy as np
# Public API of this module: only the fixer class is exported. The
# _fix_spectra* functions below are internal workers handed to
# SpectraVelocityFixer._process (kept at module level, presumably so that
# multiprocessing can pickle them).
__all__ = ['SpectraVelocityFixer']
################################################################################
class SpectraVelocityFixer(object):
    '''
    Apply a Gaussian velocity smoothing (``gauss_smooth`` module) to one or
    more spectra sharing a common, equally spaced wavelength grid.

    Spectrum ``i`` is processed with velocity ``v_0[i]`` (passed *negated*
    to the smoothing routine -- presumably to undo the shift; confirm in
    ``gauss_smooth``) and a kernel dispersion equal to the quadrature
    difference ``sqrt(target_vd**2 - v_d[i]**2)``, or 0 when
    ``v_d[i] >= target_vd`` (see ``_getVd``).

    Parameters
    ----------
    l_obs : array_like
        Wavelength grid shared by all spectra. Must be equally spaced,
        otherwise ``ValueError`` is raised.
    v_0 : float or array_like
        Velocity of each spectrum. A scalar describes a single spectrum.
    v_d : float or array_like
        Current velocity dispersion of each spectrum. A scalar describes a
        single spectrum.
    nproc : int or None
        ``1`` processes spectra serially in this process; any other value
        (including ``None``) uses ``multiprocessing.Pool(nproc)``.
    '''

    def __init__(self, l_obs, v_0, v_d, nproc=None):
        self.l_obs = np.ascontiguousarray(l_obs, 'float64')
        uniform = np.linspace(self.l_obs.min(), self.l_obs.max(), len(self.l_obs))
        if not np.allclose(self.l_obs, uniform):
            raise ValueError('l_obs is not equally spaced.')
        # atleast_1d turns Python/NumPy scalars *and* 0-d arrays into shape
        # (1,), so that v_0[i] / v_d[i] indexing always works.
        self.v_0 = np.atleast_1d(np.asarray(v_0))
        self.v_d = np.atleast_1d(np.asarray(v_d))
        self.nproc = nproc

    @staticmethod
    def _as_columns(a):
        '''Return ``a`` as a 2-D array; a 1-D array becomes a single column.'''
        a = np.asarray(a)
        if a.ndim == 1:
            return a[:, np.newaxis]
        return a

    def _check_flux_shape(self, flux):
        '''
        Validate ``flux`` against the wavelength grid and ``v_0``.

        Returns True when ``flux`` is a single 1-D spectrum (same shape as
        ``l_obs``), False when it has shape ``(len(l_obs),) + v_0.shape``.
        Raises ValueError otherwise.
        '''
        if flux.shape == self.l_obs.shape:
            return True
        shape = (len(self.l_obs),) + self.v_0.shape
        if flux.shape != shape:
            raise ValueError('flux has an incorrect shape: %s. Should be %s.'
                             % (flux.shape, shape))
        return False

    def _paramsFlagged(self, flux, err, flag, v_d, fill, fill_val):
        '''
        Yield one argument tuple per spectrum (column) for ``_fix_spectra_flag``.

        1-D inputs are treated as a single column. Each tuple is
        ``(l_obs, flux_i, err_i, flag_i, v_0[i], v_d[i], fill, fill_val)``
        with contiguous float64 flux/err and bool flag columns.
        '''
        flux = self._as_columns(flux)
        err = self._as_columns(err)
        flag = self._as_columns(flag)
        for i in range(flux.shape[1]):
            yield (self.l_obs,
                   np.ascontiguousarray(flux[:, i], 'float64'),
                   np.ascontiguousarray(err[:, i], 'float64'),
                   np.ascontiguousarray(flag[:, i], 'bool'),
                   self.v_0[i], v_d[i], fill, fill_val)

    # Backward-compatible alias: the original class carried an identical copy
    # of _paramsFlagged under this name.
    _params_flag = _paramsFlagged

    def _params(self, flux, v_d, fill, fill_val):
        '''
        Yield one argument tuple per spectrum (column) for ``_fix_spectra``.

        1-D ``flux`` is treated as a single column. Each tuple is
        ``(l_obs, flux_i, v_0[i], v_d[i], fill, fill_val)``.
        '''
        flux = self._as_columns(flux)
        for i in range(flux.shape[1]):
            yield (self.l_obs,
                   np.ascontiguousarray(flux[:, i], 'float64'),
                   self.v_0[i], v_d[i], fill, fill_val)

    def _getVd(self, target_vd):
        '''
        Return the extra dispersion needed to bring each ``v_d`` to ``target_vd``.

        Computed as ``sqrt(target_vd**2 - v_d**2)`` where ``v_d < target_vd``
        and 0 elsewhere. The result is always float64, so integer ``v_d``
        input no longer truncates the square root.
        '''
        v_d = np.asarray(self.v_d, dtype='float64')
        vd_fix = np.zeros(v_d.shape, dtype='float64')
        need = v_d < target_vd
        vd_fix[need] = np.sqrt(target_vd ** 2 - v_d[need] ** 2)
        return vd_fix

    def _process(self, func, params):
        '''
        Apply ``func`` to every item of ``params`` and return the results as a list.

        Runs serially when ``nproc == 1``; otherwise uses a process pool,
        which is always shut down (terminate + join) before returning, so
        worker processes are not leaked.
        '''
        if self.nproc == 1:
            return [func(args) for args in params]
        import multiprocessing
        pool = multiprocessing.Pool(self.nproc)
        try:
            return pool.map(func, params)
        finally:
            # map() has already collected every result (or raised), so the
            # workers can be stopped unconditionally.
            pool.terminate()
            pool.join()

    def fixFlagged(self, flux, err, flag, target_vd=0.0, fill='nearest', fill_val=0.0):
        '''
        Smooth spectra together with their errors and flags.

        ``flux``, ``err`` and ``flag`` must share one shape: either that of
        ``l_obs`` (a single spectrum) or ``(len(l_obs),) + v_0.shape``.
        The dispersion is only increased where ``v_d < target_vd``.

        Returns ``(flux, err, flag)``. In the 1-D case these are the first
        three items of the single worker result; otherwise they are arrays
        shaped like the inputs, filled column by column.
        Raises ValueError on inconsistent shapes.
        '''
        flux = np.asarray(flux)
        err = np.asarray(err)
        flag = np.asarray(flag)
        one_d = self._check_flux_shape(flux)
        if err.shape != flux.shape or flag.shape != flux.shape:
            raise ValueError('err and flag must have the same shape as flux (%s), got %s and %s.'
                             % (flux.shape, err.shape, flag.shape))
        vd = self._getVd(target_vd)
        params = self._paramsFlagged(flux, err, flag, vd, fill, fill_val)
        out = self._process(_fix_spectra_flag, params)
        if one_d:
            first = out[0]
            return first[0], first[1], first[2]
        flux_s = np.empty_like(flux)
        err_s = np.empty_like(err)
        flag_s = np.empty_like(flag)
        for i, z in enumerate(out):
            flux_s[:, i], err_s[:, i], flag_s[:, i] = z[0], z[1], z[2]
        return flux_s, err_s, flag_s

    def fix(self, flux, target_vd=0.0, fill='nearest', fill_val=0.0):
        '''
        Smooth spectra (without errors/flags).

        ``flux`` has the shape of ``l_obs`` (a single spectrum) or
        ``(len(l_obs),) + v_0.shape``. The dispersion is only increased
        where ``v_d < target_vd``.

        Returns the single smoothed spectrum in the 1-D case; otherwise the
        worker results stacked with ``np.array(out).T``, i.e. one column per
        spectrum. Raises ValueError on a mismatched shape.
        '''
        flux = np.asarray(flux)
        one_d = self._check_flux_shape(flux)
        vd = self._getVd(target_vd)
        params = self._params(flux, vd, fill, fill_val)
        out = self._process(_fix_spectra, params)
        if one_d:
            return out[0]
        return np.array(out).T
################################################################################
def _fix_spectra_flag(args):
    '''
    Worker: smooth a single flagged spectrum.

    ``args`` is one tuple produced by ``SpectraVelocityFixer._paramsFlagged``:
    ``(l_obs, flux, err, flag, v_0, v_d, fill, fill_val)``. It is forwarded
    to ``gauss_smooth.gaussVelocitySmoothFlag`` with the velocity negated and
    fixed settings ``n_sig=4``, ``n_u=21``, ``flag_threshold=0.5`` (their
    exact meaning is defined in ``gauss_smooth`` -- confirm there).
    Returns whatever that routine returns; ``fixFlagged`` reads items
    0, 1 and 2 of it as flux, err and flag.
    '''
    # Imported lazily inside the worker, matching the original layout.
    from gauss_smooth import gaussVelocitySmoothFlag  # @UnresolvedImport
    (wavelengths, spec, spec_err, spec_flag,
     velocity, dispersion, fill_mode, fill_value) = args
    shifted_velocity = -velocity
    return gaussVelocitySmoothFlag(wavelengths, spec, spec_err, spec_flag,
                                   shifted_velocity, dispersion,
                                   n_sig=4,
                                   n_u=21,
                                   fill=fill_mode,
                                   fill_val=fill_value,
                                   flag_threshold=0.5)
################################################################################
################################################################################
def _fix_spectra(args):
    '''
    Worker: smooth a single spectrum (no errors or flags).

    ``args`` is one tuple produced by ``SpectraVelocityFixer._params``:
    ``(l_obs, flux, v_0, v_d, fill, fill_val)``. It is forwarded to
    ``gauss_smooth.gaussVelocitySmooth`` with the velocity negated and fixed
    settings ``n_sig=4``, ``n_u=21`` (meaning defined in ``gauss_smooth`` --
    confirm there). Returns that routine's result unchanged.
    '''
    # Imported lazily inside the worker, matching the original layout.
    from gauss_smooth import gaussVelocitySmooth  # @UnresolvedImport
    wavelengths, spec, velocity, dispersion, fill_mode, fill_value = args
    shifted_velocity = -velocity
    return gaussVelocitySmooth(wavelengths, spec,
                               shifted_velocity, dispersion,
                               n_sig=4,
                               n_u=21,
                               fill=fill_mode,
                               fill_val=fill_value)
################################################################################
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment