Skip to content

Instantly share code, notes, and snippets.

@cphyc
Created March 29, 2023 13:40
Show Gist options
  • Save cphyc/7477723d247078dd0388d07f5689711f to your computer and use it in GitHub Desktop.
Extract data from a simulation into .h5 file, then load them and gather them using xarray.
import argparse
import gc
from functools import wraps
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
import h5py
import joblib
import numpy as np
import pynbody
import yt
from astrophysics_toolset.utilities.logging import logger
from yt.fields.derived_field import ValidateSpatial
from yt.utilities.parallel_tools.parallel_analysis_interface import communication_system
yt.enable_parallelism()
logger.setLevel(10)
def setup_dataset(
ds, iord_baryons: np.ndarray
):
@yt.particle_filter(requires=["particle_family"], filtered_type="io")
def tracers(pfilter, data):
return data[(pfilter.filtered_type, "particle_family")] <= 0
@yt.particle_filter(requires=["particle_identity"], filtered_type="tracers")
def baryon_tracers(pfilter, data):
return np.in1d(data[(pfilter.filtered_type, "particle_identity")], iord_baryons)
@yt.particle_filter(requires=["particle_identity"], filtered_type="DM")
def selected_DM(pfilter, data):
return np.in1d(data[(pfilter.filtered_type, "particle_identity")], iord_dm)
def _corresponding_DM_ids(field, data):
ids = data[field.name[0], "particle_identity"].astype(int)
corresponding_iord_dm
ind = np.searchsorted(iord_baryons, ids)
return data.apply_units(corresponding_iord_dm[ind], "1")
def _velocity_dispersion(field, data):
from itertools import product
new_field = np.zeros_like(data["gas", "velocity_x"])
for i, j, k in product(*[range(2)] * 3):
v = 0
w = np.ones_like(data["gas", "density"][i : i + 3, j : j + 3, k : k + 3])
for kk in "xyz":
vel_block = data["gas", f"velocity_{kk}"][
i : i + 3, j : j + 3, k : k + 3
]
vmean = np.average(vel_block, weights=w, axis=(0, 1, 2))
v += np.average((vel_block - vmean) ** 2, weights=w, axis=(0, 1, 2))
# v += np.average((vel_block - np.mean(vel_block))**2, weights=mass_block, axis=(0, 1, 2))
# v += vel_block.var(axis=(0, 1, 2))
new_field[i + 1, j + 1, k + 1] = np.sqrt(v)
return data.apply_units(new_field, data["gas", "velocity_x"].units)
ds.add_field(
("gas", "velocity_dispersion"),
_velocity_dispersion,
sampling_type="cell",
validators=[ValidateSpatial(ghost_zones=1)],
units="cm/s",
)
ds.add_particle_filter("tracers")
ds.add_particle_filter("baryon_tracers")
ds.add_particle_filter("selected_DM")
mesh_fields = [
*(("gas", f"velocity_{k}") for k in "xyz"),
("gas", "density"),
("gas", "temperature"),
*(
("gas", f"{species}_number_density")
for species in ("HI", "HII", "HeI", "HeII")
),
("gas", "sound_speed"),
("gas", "velocity_dispersion"),
("gas", "cooling_total"),
("gas", "heating_total"),
("gas", "cooling_net"),
]
for field in mesh_fields:
ds.add_mesh_sampling_particle_field(field, ptype="baryon_tracers")
# for k in "xyz":
# ds.add_mesh_sampling_particle_field(
# ("gas", f"velocity_{k}"), ptype="baryon_tracers"
# )
# ds.add_mesh_sampling_particle_field(("gas", "density"), ptype="baryon_tracers")
# ds.add_mesh_sampling_particle_field(("gas", "temperature"), ptype="baryon_tracers")
# ds.add_mesh_sampling_particle_field(
# ("gas", "HI_number_density"), ptype="baryon_tracers"
# )
# ds.add_mesh_sampling_particle_field(
# ("gas", "HII_number_density"), ptype="baryon_tracers"
# )
# ds.add_mesh_sampling_particle_field(
# ("gas", "HeI_number_density"), ptype="baryon_tracers"
# )
# ds.add_mesh_sampling_particle_field(
# ("gas", "HeII_number_density"), ptype="baryon_tracers"
# )
ds.add_field(
("baryon_tracers", "DM_identity"),
function=_corresponding_DM_ids,
sampling_type="particle",
units="1",
)
def extract_data(
ds,
fields: List[Tuple[str, str]],
iord_baryons: np.ndarray,
# corresponding_iord_dm: np.ndarray,
# iord_dm: np.ndarray,
):
out_folder = (Path(ds.directory).parent / "subset").resolve()
name = f"{str(ds)}_region.h5"
out_filename = out_folder / name
found_fields = []
if out_filename.exists():
found_fields = []
with h5py.File(out_filename, "r") as f:
for ft in f:
for fn in f[ft]:
found_fields.append((ft, fn))
# Now check all fields have been registered
missing = [f for f in fields if f not in found_fields]
if len(missing) > 0:
logger.info(
"Found data file %s, but missing %s fields", out_filename, len(missing)
)
else:
logger.info("Found data file %s, all fields found", out_filename)
return
logger.info("Extracting data from %s", ds)
setup_dataset(ds, iord_baryons) #, corresponding_iord_dm, iord_dm)
ad = ds.all_data()
yt.funcs.mylog.info("Computing cell indices")
ad["baryon_tracers", "cell_index"]
yt.funcs.mylog.info("Writing dataset into %s", out_filename)
out_filename.parent.mkdir(parents=True, exist_ok=True)
ad.save_as_dataset(str(out_filename), fields=fields)
del ad, ds
gc.collect()
def main(argv: Optional[Sequence] = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--simulation",
type=str,
help="Path to the folder containing all outputs.",
)
parser.add_argument(
"-z",
"--z-target",
default=2,
type=float,
help="Redshift where to extract the Lagrangian patch (default: %(default)s).",
)
parser.add_argument(
"--R200-fraction",
default=2,
type=float,
help=(
"Maximum distance to look for particles, "
"in units of R200 (default: %(default)s)."
),
)
parser.add_argument(
"--output-slice",
default=None,
help=(
"Output slice to consider (in the form istart:iend:istep), "
"useful when parallelizing manually (default: %(default)s)."
),
)
args = parser.parse_args(argv)
# Find baryons & DM ids in the main galaxy
# memory = joblib.Memory(Path(args.simulation) / "cache")
# build_baryonic_patch_cached = memory.cache(build_baryonic_patch)
# iord_baryons, corresponding_iord_baryons, iord_dm = build_baryonic_patch_cached(
# path=args.simulation,
# fR200=args.R200_fraction,
# z_target=args.z_target,
# )
iord_baryons = np.loadtxt(Path(__file__).parent.parent / "nut_ids.csv_but_for_real.csv")
# Build field list
fields = []
for __ in ("baryon_tracers", ): # "selected_DM"):
fields.extend(
[
(__, "particle_family"),
(__, "particle_mass"),
(__, "particle_identity"),
*[(__, f"particle_position_{k}") for k in "xyz"],
*[(__, f"particle_velocity_{k}") for k in "xyz"],
(__, "particle_position"),
]
)
if __ == "baryon_tracers":
fields.extend(
[
*[(__, f"cell_gas_velocity_{k}") for k in "xyz"],
(__, "cell_gas_density"),
(__, "cell_gas_temperature"),
(__, "cell_gas_HI_number_density"),
(__, "cell_gas_HII_number_density"),
(__, "cell_gas_HeI_number_density"),
(__, "cell_gas_HeII_number_density"),
(__, "DM_identity"),
(__, "cell_gas_sound_speed"),
(__, "cell_gas_velocity_dispersion"),
(__, "cell_gas_cooling_total"),
(__, "cell_gas_heating_total"),
(__, "cell_gas_cooling_net"),
]
)
# Loop over all datasets
simu = Path(args.simulation)
outputs = [
out
for out in sorted(
list(simu.glob("output_?????")) + list(simu.glob("output_?????.tar.gz"))
)
if not (
out.name.endswith(".tar.gz")
and out.with_name(out.name.replace(".tar.gz", "")).exists()
)
]
if args.output_slice is not None:
istart, iend, istep = (int(i) if i else None for i in args.output_slice.split(":"))
sl = slice(istart, iend, istep)
outputs = outputs[sl]
pbar = yt.funcs.get_pbar("Constructing trajectory information", len(outputs))
yt.set_log_level(40)
for output in yt.parallel_objects(list(reversed(outputs))):
bbox = [[0.49] * 3, [0.51] * 3]
if output.name.endswith(".tar.gz"):
original_name = output.name.replace(".tar.gz", "")
try:
ds = yt.load_archive(output, original_name, mount_timeout=5, bbox=bbox)
except yt.utilities.exceptions.YTUnidentifiedDataType:
continue
ds.directory = str(output.parent / original_name)
else:
try:
ds = yt.load(output, bbox=bbox)
except yt.utilities.exceptions.YTUnidentifiedDataType:
continue
extract_data(ds, fields, iord_baryons)
pbar.update(outputs.index(output))
if __name__ == "__main__":
main()
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment