Skip to content

Instantly share code, notes, and snippets.

@cphyc
Last active September 6, 2023 08:32
Show Gist options
  • Save cphyc/e42f815ddd78600366e03ff6024b4f9c to your computer and use it in GitHub Desktop.
Extract tracer data from a RAMSES simulation.
import argparse
import gc
from pathlib import Path
from typing import List, Optional, Sequence, Tuple
import h5py
import numpy as np
import yt
from yt import mylog as logger
from yt.fields.derived_field import ValidateSpatial
# Enable MPI-based parallelism so yt.parallel_objects() below distributes
# outputs across ranks when the script is launched under mpirun.
yt.enable_parallelism()
# 10 == logging.DEBUG: make yt's logger maximally verbose.
logger.setLevel(10)
def setup_dataset(ds):
    """Register, in place on *ds*, the tracer particle filter, the derived
    velocity-dispersion field, and mesh-sampled fields for the tracers."""

    @yt.particle_filter(requires=["particle_family"], filtered_type="io")
    def tracers(pfilter, data):
        # Tracer particles are flagged with a non-positive family id.
        return data[(pfilter.filtered_type, "particle_family")] <= 0

    def _velocity_dispersion(field, data):
        from itertools import product

        sigma = np.zeros_like(data["gas", "velocity_x"])
        # With one ghost zone, the 2x2x2 oct interior lives at indices 1..2;
        # each interior cell gets the dispersion over its 3x3x3 neighbourhood.
        for ii, jj, kk in product(range(2), repeat=3):
            window = (slice(ii, ii + 3), slice(jj, jj + 3), slice(kk, kk + 3))
            weights = np.ones_like(data["gas", "density"][window])
            variance = 0
            for axis in "xyz":
                vel = data["gas", f"velocity_{axis}"][window]
                mean = np.average(vel, weights=weights, axis=(0, 1, 2))
                variance += np.average(
                    (vel - mean) ** 2, weights=weights, axis=(0, 1, 2)
                )
            sigma[ii + 1, jj + 1, kk + 1] = np.sqrt(variance)
        return data.apply_units(sigma, data["gas", "velocity_x"].units)

    ds.add_field(
        ("gas", "velocity_dispersion"),
        _velocity_dispersion,
        sampling_type="cell",
        validators=[ValidateSpatial(ghost_zones=1)],
        units="cm/s",
    )
    ds.add_particle_filter("tracers")

    # Gas fields to sample at each tracer's cell.
    mesh_fields = [("gas", f"velocity_{ax}") for ax in "xyz"]
    mesh_fields += [
        ("gas", "density"),
        ("gas", "temperature"),
    ]
    mesh_fields += [
        ("gas", f"{species}_number_density")
        for species in ("HI", "HII", "HeI", "HeII")
    ]
    mesh_fields += [
        ("gas", "sound_speed"),
        ("gas", "velocity_dispersion"),
        ("gas", "cooling_total"),
        ("gas", "heating_total"),
        ("gas", "cooling_net"),
    ]
    mesh_fields += [("gas", f"vorticity_{ax}") for ax in "xyz"]
    for mesh_field in mesh_fields:
        ds.add_mesh_sampling_particle_field(mesh_field, ptype="tracers")
def extract_data(
    ds,
    fields: List[Tuple[str, str]],
):
    """Write the requested tracer *fields* of *ds* to <sim>/subset/<ds>_region.h5.

    The extraction is skipped when the output file already contains every
    requested field.
    """
    out_folder = (Path(ds.directory).parent / "subset").resolve()
    out_filename = out_folder / f"{str(ds)}_region.h5"

    if out_filename.exists():
        # Collect the (field_type, field_name) pairs already on disk.
        with h5py.File(out_filename, "r") as h5f:
            found_fields = [
                (ftype, fname) for ftype in h5f for fname in h5f[ftype]
            ]
        # Now check all fields have been registered
        missing = [fld for fld in fields if fld not in found_fields]
        if not missing:
            logger.info("Found data file %s, all fields found", out_filename)
            return
        logger.info(
            "Found data file %s, but missing %s fields", out_filename, len(missing)
        )

    logger.info("Extracting data from %s", ds)
    setup_dataset(ds)
    ad = ds.all_data()
    yt.funcs.mylog.info("Computing cell indices")
    # Touch the field so the tracer->cell mapping is computed (and cached).
    ad["tracers", "cell_index"]
    yt.funcs.mylog.info("Writing dataset into %s", out_filename)
    out_filename.parent.mkdir(parents=True, exist_ok=True)
    ad.save_as_dataset(str(out_filename), fields=fields)
    # Drop references before the next output to keep peak memory down.
    del ad, ds
    gc.collect()
def main(argv: Optional[Sequence] = None) -> int:
    """Command-line entry point: extract tracer data from every output of a
    RAMSES run (folders ``output_?????`` or tarballs ``output_?????.tar.gz``).

    Returns 0 on success; outputs that yt cannot identify are skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--simulation",
        required=True,
        type=str,
        help="Path to the folder containing all outputs.",
    )
    parser.add_argument(
        "--output-slice",
        default=None,
        help=(
            "Slices of the output to consider (in the form istart:iend:istep), "
            "useful when parallelizing manually (default: %(default)s)."
        ),
    )
    parser.add_argument(
        "--bbox",
        nargs=6,
        default=[0, 0, 0, 1, 1, 1],
        type=float,
        help=(
            "Bounding box to use for the region to extract in box units "
            "(default: %(default)s)."
        ),
    )
    args = parser.parse_args(argv)

    # Build field list
    fields = [
        ("tracers", "particle_family"),
        ("tracers", "particle_mass"),
        ("tracers", "particle_identity"),
        *[("tracers", f"particle_position_{k}") for k in "xyz"],
        *[("tracers", f"particle_velocity_{k}") for k in "xyz"],
        ("tracers", "particle_position"),
        *[("tracers", f"cell_gas_velocity_{k}") for k in "xyz"],
        ("tracers", "cell_gas_density"),
        ("tracers", "cell_gas_temperature"),
        ("tracers", "cell_gas_HI_number_density"),
        ("tracers", "cell_gas_HII_number_density"),
        ("tracers", "cell_gas_HeI_number_density"),
        ("tracers", "cell_gas_HeII_number_density"),
        ("tracers", "cell_gas_sound_speed"),
        ("tracers", "cell_gas_velocity_dispersion"),
        ("tracers", "cell_gas_cooling_total"),
        ("tracers", "cell_gas_heating_total"),
        ("tracers", "cell_gas_cooling_net"),
        *[("tracers", f"cell_gas_vorticity_{k}") for k in "xyz"],
    ]

    # Loop over all datasets; a tarball is skipped when the already-extracted
    # folder of the same name exists next to it.
    simu = Path(args.simulation)
    outputs = [
        out
        for out in sorted(
            list(simu.glob("output_?????")) + list(simu.glob("output_?????.tar.gz"))
        )
        if not (
            out.name.endswith(".tar.gz")
            and out.with_name(out.name.replace(".tar.gz", "")).exists()
        )
    ]
    if args.output_slice is not None:
        # BUGFIX: the previous 3-way unpacking raised ValueError for any
        # slice with fewer than two colons (e.g. "0:10" or "5:"). Accept
        # "istart", "istart:iend" and "istart:iend:istep" by padding with
        # None, mirroring Python's own slice syntax.
        parts = [int(i) if i else None for i in args.output_slice.split(":")]
        if len(parts) > 3:
            parser.error("--output-slice takes at most istart:iend:istep")
        parts += [None] * (3 - len(parts))
        outputs = outputs[slice(*parts)]

    pbar = yt.funcs.get_pbar("Constructing trajectory information", len(outputs))
    yt.set_log_level(40)
    bbox = [args.bbox[0:3], args.bbox[3:6]]
    # Iterate most-recent first; parallel_objects spreads outputs over ranks.
    for output in yt.parallel_objects(list(reversed(outputs))):
        if output.name.endswith(".tar.gz"):
            original_name = output.name.replace(".tar.gz", "")
            try:
                ds = yt.load_archive(output, original_name, mount_timeout=5, bbox=bbox)
            except yt.utilities.exceptions.YTUnidentifiedDataType:
                continue
            # Point ds.directory at the (virtual) extracted folder so that
            # extract_data writes next to the simulation outputs.
            ds.directory = str(output.parent / original_name)
        else:
            try:
                ds = yt.load(output, bbox=bbox)
            except yt.utilities.exceptions.YTUnidentifiedDataType:
                continue
        extract_data(ds, fields)
        pbar.update(outputs.index(output))
    return 0
if __name__ == "__main__":
    import sys

    # SystemExit carries main()'s return code back to the shell.
    raise SystemExit(main(sys.argv[1:]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment