Skip to content

Instantly share code, notes, and snippets.

@chrishavlin
Created December 20, 2022 19:55
Show Gist options
Save chrishavlin/4c8e55179299e60d88dc5129e32cc676 to your computer and use it in GitHub Desktop.
diff --git a/yt/_typing.py b/yt/_typing.py
index 0062c63dd..15fcd8a6f 100644
--- a/yt/_typing.py
+++ b/yt/_typing.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import unyt as un
from numpy import ndarray
@@ -28,3 +28,10 @@ Unit = Union[un.Unit, str]
# types that can be converted to un.unyt_quantity
Quantity = Union[un.unyt_quantity, Tuple[float, Unit]]
+
+# types for stream loaders; UniformGridData is load_uniform_grid's `data` type
+StreamGridCallable = Callable[[Any, FieldName], ndarray]
+UniformGridData = Dict[  # field key -> data source for that field
+ AnyFieldKey,
+ Union[ndarray, StreamGridCallable, Tuple[Union[ndarray, StreamGridCallable], Unit]],
+]  # each value: array, callable, or (array-or-callable, unit) tuple
diff --git a/yt/loaders.py b/yt/loaders.py
index c32b31f1c..6e5ece628 100644
--- a/yt/loaders.py
+++ b/yt/loaders.py
@@ -22,7 +22,7 @@ from yt._maintenance.deprecation import (
future_positional_only,
issue_deprecation_warning,
)
-from yt._typing import AnyFieldKey, AxisOrder, FieldKey
+from yt._typing import AnyFieldKey, AxisOrder, FieldKey, UniformGridData
from yt.data_objects.static_output import Dataset
from yt.funcs import levenshtein_distance
from yt.sample_data.api import lookup_on_disk_data
@@ -174,7 +174,7 @@ def _sanitize_axis_order_args(
def load_uniform_grid(
- data: Dict[AnyFieldKey, np.ndarray],
+ data: UniformGridData,
domain_dimensions,
length_unit=None,
bbox=None,
@@ -293,14 +293,20 @@ def load_uniform_grid(
sfh = StreamDictFieldHandler()
+ # record the size of any in-memory fields. This simplifies type checking.
+ in_mem_fields_shape: Dict[AnyFieldKey, Tuple[int, ...]] = {}
+ for ky, val in data.items():
+ if isinstance(val, np.ndarray):
+ in_mem_fields_shape[ky] = val.shape
+
if number_of_particles > 0:
particle_types = set_particle_types(data)
# Used much further below.
pdata: Dict[Union[str, FieldKey], Any] = {
"number_of_particles": number_of_particles
}
- for key in list(data.keys()):
- if len(data[key].shape) == 1 or key[0] == "io":
+ for key in in_mem_fields_shape.keys():
+ if len(in_mem_fields_shape[key]) == 1 or key[0] == "io":
field: FieldKey
if not isinstance(key, tuple):
field = ("io", key)
@@ -313,15 +319,15 @@ def load_uniform_grid(
particle_types = {}
if nprocs > 1:
- temp = {}
- new_data = {} # type: ignore [var-annotated]
+ temp: Dict[AnyFieldKey, np.ndarray] = {}
+ new_data: Dict[int, Dict[AnyFieldKey, np.ndarray]] = {}
for key in data.keys():
- psize = get_psize(np.array(data[key].shape), nprocs)
+ psize = get_psize(np.array(in_mem_fields_shape[key]), nprocs)
grid_left_edges, grid_right_edges, shapes, slices = decompose_array(
- data[key].shape, psize, bbox
+ in_mem_fields_shape[key], psize, bbox
)
grid_dimensions = np.array([shape for shape in shapes], dtype="int32")
- temp[key] = [data[key][slice] for slice in slices]
+ temp[key] = [data[key][slice] for slice in slices] # type: ignore
for gid in range(nprocs):
new_data[gid] = {}
for key in temp.keys():
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.