Last active
September 6, 2023 08:32
-
-
Save cphyc/e42f815ddd78600366e03ff6024b4f9c to your computer and use it in GitHub Desktop.
Extract tracer data from a RAMSES simulation.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import gc | |
from pathlib import Path | |
from typing import List, Optional, Sequence, Tuple | |
import h5py | |
import numpy as np | |
import yt | |
from yt import mylog as logger | |
from yt.fields.derived_field import ValidateSpatial | |
# Put yt into its MPI-aware mode; main() hands outputs to yt.parallel_objects.
yt.enable_parallelism()
# Numeric level 10 is logging.DEBUG. NOTE: main() later calls
# yt.set_log_level(40), which raises the threshold again for the actual run.
logger.setLevel(level=10)
def setup_dataset(ds):
    """Register the tracer filter and the derived/sampled gas fields on ``ds``.

    ``ds`` is modified in place:

    * a ``"tracers"`` particle filter (``io`` particles whose
      ``particle_family`` is <= 0) is registered and attached;
    * a ``("gas", "velocity_dispersion")`` cell field is added;
    * every gas field listed in ``mesh_fields`` is sampled at the tracer
      positions with ``ds.add_mesh_sampling_particle_field`` (presumably
      producing the ``("tracers", "cell_gas_<name>")`` fields that ``main``
      requests).

    Parameters
    ----------
    ds
        A dataset returned by ``yt.load`` / ``yt.load_archive``.
    """
    from itertools import product

    # The decorator registers the filter by name. It runs on every call of
    # setup_dataset, so the filter is re-registered for each dataset.
    @yt.particle_filter(requires=["particle_family"], filtered_type="io")
    def tracers(pfilter, data):
        # NOTE(review): non-positive particle_family values are presumably the
        # RAMSES tracer families — confirm for the RAMSES version in use.
        return data[(pfilter.filtered_type, "particle_family")] <= 0

    def _velocity_dispersion(field, data):
        """Velocity dispersion over the 3x3x3 neighbourhood of interior cells.

        For every ``(i, j, k)`` in ``{0, 1}**3`` the value stored at cell
        ``(i + 1, j + 1, k + 1)`` is the square root of the sum, over the x,
        y and z components, of the (unweighted) mean squared deviation of the
        velocity from its mean over ``data[i:i+3, j:j+3, k:k+3]``. Only these
        2x2x2 positions are written; all other cells of the result are 0.

        The field is declared with ``ValidateSpatial(ghost_zones=1)``, so yt
        hands in blocks padded with ghost cells. NOTE(review): the loop
        assumes 4 cells along each of the first three axes (a 2x2x2 oct plus
        one ghost layer, as for RAMSES) — other block sizes would only be
        partially filled.
        """
        axes = (0, 1, 2)
        # Fetch each component once rather than on every window iteration.
        components = [data["gas", f"velocity_{ax}"] for ax in "xyz"]
        new_field = np.zeros_like(components[0])
        for i, j, k in product(range(2), repeat=3):
            window = (slice(i, i + 3), slice(j, j + 3), slice(k, k + 3))
            var = 0
            for comp in components:
                block = comp[window]
                # Uniform weights: a plain mean equals np.average with ones.
                var += ((block - block.mean(axis=axes)) ** 2).mean(axis=axes)
            new_field[i + 1, j + 1, k + 1] = np.sqrt(var)
        return data.apply_units(new_field, components[0].units)

    ds.add_field(
        ("gas", "velocity_dispersion"),
        _velocity_dispersion,
        sampling_type="cell",
        validators=[ValidateSpatial(ghost_zones=1)],
        units="cm/s",
    )
    ds.add_particle_filter("tracers")

    # Gas fields to sample at each tracer's position.
    mesh_fields = [
        *(("gas", f"velocity_{k}") for k in "xyz"),
        ("gas", "density"),
        ("gas", "temperature"),
        *(
            ("gas", f"{species}_number_density")
            for species in ("HI", "HII", "HeI", "HeII")
        ),
        ("gas", "sound_speed"),
        ("gas", "velocity_dispersion"),
        ("gas", "cooling_total"),
        ("gas", "heating_total"),
        ("gas", "cooling_net"),
        *(("gas", f"vorticity_{k}") for k in "xyz"),
    ]
    for field in mesh_fields:
        ds.add_mesh_sampling_particle_field(field, ptype="tracers")
def _stored_fields(filename):
    """Return the set of ``(field_type, field_name)`` pairs stored in ``filename``.

    Each top-level HDF5 group is read as a field type and each of its members
    as a field name. Top-level items that are not groups are ignored. If h5py
    cannot open the file (``OSError``, e.g. a file truncated by an interrupted
    run), a warning is logged and an empty set is returned, so the caller
    regenerates the file instead of crashing.
    """
    found = set()
    try:
        with h5py.File(filename, "r") as handle:
            for ftype, group in handle.items():
                if isinstance(group, h5py.Group):
                    found.update((ftype, fname) for fname in group)
    except OSError:
        logger.warning("Could not read %s; it will be regenerated", filename)
        return set()
    return found


def extract_data(
    ds,
    fields: List[Tuple[str, str]],
):
    """Save ``fields`` of ``ds`` to ``<parent of ds.directory>/subset/<ds>_region.h5``.

    If that file already exists and holds every requested field, nothing is
    done. Otherwise ``setup_dataset`` is applied to ``ds``, the tracers'
    ``cell_index`` field is evaluated, and ``ds.all_data()`` is written with
    ``save_as_dataset`` restricted to ``fields``. NOTE(review): an existing but
    incomplete file is presumably overwritten by ``save_as_dataset`` — confirm.

    Parameters
    ----------
    ds
        Dataset to extract from; ``setup_dataset`` mutates it in place.
    fields
        ``(field_type, field_name)`` pairs to save.
    """
    out_folder = (Path(ds.directory).parent / "subset").resolve()
    out_filename = out_folder / f"{str(ds)}_region.h5"

    if out_filename.exists():
        found_fields = _stored_fields(out_filename)
        missing = [field for field in fields if tuple(field) not in found_fields]
        if not missing:
            logger.info("Found data file %s, all fields found", out_filename)
            return
        logger.info(
            "Found data file %s, but missing %s fields", out_filename, len(missing)
        )

    logger.info("Extracting data from %s", ds)
    setup_dataset(ds)
    ad = ds.all_data()

    # Evaluate the tracers' cell_index field up front; the cell_gas_* fields
    # saved below presumably reuse it via yt's field cache on ``ad``.
    logger.info("Computing cell indices")
    ad["tracers", "cell_index"]

    logger.info("Writing dataset into %s", out_filename)
    out_folder.mkdir(parents=True, exist_ok=True)
    ad.save_as_dataset(str(out_filename), fields=fields)

    # Drop the local references and collect. The caller may still hold ``ds``.
    del ad, ds
    gc.collect()
def _parse_output_slice(spec: str) -> slice:
    """Parse ``"istart:iend"`` or ``"istart:iend:istep"`` into a ``slice``.

    Empty components become ``None``, as in Python slice syntax
    (``"::2"`` -> ``slice(None, None, 2)``).

    Raises
    ------
    ValueError
        If ``spec`` does not have two or three colon-separated parts, a part
        is not an integer, or the step is zero.
    """
    parts = spec.split(":")
    if len(parts) not in (2, 3):
        raise ValueError(
            f"invalid --output-slice {spec!r}: expected istart:iend[:istep]"
        )
    try:
        bounds = [int(part) if part else None for part in parts]
    except ValueError:
        raise ValueError(
            f"invalid --output-slice {spec!r}: components must be integers"
        ) from None
    if len(bounds) == 3 and bounds[2] == 0:
        raise ValueError(f"invalid --output-slice {spec!r}: step cannot be zero")
    return slice(*bounds)


def _list_outputs(simu: Path) -> List[Path]:
    """Return the ``output_?????`` entries and ``output_?????.tar.gz`` archives in ``simu``.

    The result is sorted by path. An archive is dropped when the same name
    without the ``.tar.gz`` suffix also exists.
    """
    suffix = ".tar.gz"
    candidates = sorted(
        list(simu.glob("output_?????")) + list(simu.glob(f"output_?????{suffix}"))
    )
    return [
        out
        for out in candidates
        if not (
            out.name.endswith(suffix)
            and out.with_name(out.name.replace(suffix, "")).exists()
        )
    ]


def _load_output(output: Path, bbox):
    """Load one output (directory or ``.tar.gz`` archive) restricted to ``bbox``.

    Returns ``None`` when yt raises ``YTUnidentifiedDataType`` for it.
    """
    if output.name.endswith(".tar.gz"):
        original_name = output.name.replace(".tar.gz", "")
        try:
            ds = yt.load_archive(output, original_name, mount_timeout=5, bbox=bbox)
        except yt.utilities.exceptions.YTUnidentifiedDataType:
            return None
        # extract_data derives its output folder from Path(ds.directory).parent,
        # so point it at output.parent like an unpacked output would be.
        ds.directory = str(output.parent / original_name)
        return ds
    try:
        return yt.load(output, bbox=bbox)
    except yt.utilities.exceptions.YTUnidentifiedDataType:
        return None


def main(argv: Optional[Sequence] = None) -> int:
    """Command-line entry point: extract tracer data from every simulation output.

    Parameters
    ----------
    argv
        Argument list (``None`` lets argparse read ``sys.argv[1:]``).

    Returns
    -------
    int
        0 on completion. Invalid arguments make argparse exit with status 2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--simulation",
        required=True,
        type=str,
        help="Path to the folder containing all outputs.",
    )
    parser.add_argument(
        "--output-slice",
        default=None,
        help=(
            "Slices of the output to consider (in the form istart:iend[:istep]), "
            "useful when parallelizing manually (default: %(default)s)."
        ),
    )
    parser.add_argument(
        "--bbox",
        nargs=6,
        default=[0, 0, 0, 1, 1, 1],
        type=float,
        help=(
            "Bounding box to use for the region to extract in box units "
            "(default: %(default)s)."
        ),
    )
    args = parser.parse_args(argv)

    # Validate the slice before doing any work so errors surface as CLI errors.
    output_slice = None
    if args.output_slice is not None:
        try:
            output_slice = _parse_output_slice(args.output_slice)
        except ValueError as err:
            parser.error(str(err))

    # Fields to save for each output (the cell_gas_* ones are created by
    # setup_dataset via mesh sampling).
    fields = [
        ("tracers", "particle_family"),
        ("tracers", "particle_mass"),
        ("tracers", "particle_identity"),
        *[("tracers", f"particle_position_{k}") for k in "xyz"],
        *[("tracers", f"particle_velocity_{k}") for k in "xyz"],
        ("tracers", "particle_position"),
        *[("tracers", f"cell_gas_velocity_{k}") for k in "xyz"],
        ("tracers", "cell_gas_density"),
        ("tracers", "cell_gas_temperature"),
        *[
            ("tracers", f"cell_gas_{species}_number_density")
            for species in ("HI", "HII", "HeI", "HeII")
        ],
        ("tracers", "cell_gas_sound_speed"),
        ("tracers", "cell_gas_velocity_dispersion"),
        ("tracers", "cell_gas_cooling_total"),
        ("tracers", "cell_gas_heating_total"),
        ("tracers", "cell_gas_cooling_net"),
        *[("tracers", f"cell_gas_vorticity_{k}") for k in "xyz"],
    ]

    outputs = _list_outputs(Path(args.simulation))
    if output_slice is not None:
        outputs = outputs[output_slice]

    pbar = yt.funcs.get_pbar("Constructing trajectory information", len(outputs))
    # NOTE: level 40 (ERROR) also hides extract_data's logger.info messages.
    yt.set_log_level(40)
    bbox = [args.bbox[0:3], args.bbox[3:6]]

    # Process in reverse path order (highest output number first). yt's
    # progress bar takes an absolute position, so report how many items have
    # been handled. NOTE(review): under MPI each rank only iterates its own
    # share, so this count is per rank.
    work = list(reversed(outputs))
    for n_done, output in enumerate(yt.parallel_objects(work), start=1):
        ds = _load_output(output, bbox)
        if ds is not None:
            extract_data(ds, fields)
        pbar.update(n_done)
    pbar.finish()
    return 0
if __name__ == "__main__":
    import sys

    # Forward the command-line arguments explicitly and propagate main()'s
    # return value as the process exit status.
    exit_code = main(argv=sys.argv[1:])
    sys.exit(exit_code)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment