Created December 20, 2022 19:55
-
-
Save chrishavlin/4c8e55179299e60d88dc5129e32cc676 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/yt/_typing.py b/yt/_typing.py | |
index 0062c63dd..15fcd8a6f 100644 | |
--- a/yt/_typing.py | |
+++ b/yt/_typing.py | |
@@ -1,4 +1,4 @@ | |
-from typing import List, Optional, Tuple, Union | |
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
import unyt as un | |
from numpy import ndarray | |
@@ -28,3 +28,10 @@ Unit = Union[un.Unit, str] | |
# types that can be converted to un.unyt_quantity | |
Quantity = Union[un.unyt_quantity, Tuple[float, Unit]] | |
+ | |
+# types for stream loaders | |
+StreamGridCallable = Callable[[Any, FieldName], ndarray] | |
+UniformGridData = Dict[ | |
+ AnyFieldKey, | |
+ Union[ndarray, StreamGridCallable, Tuple[Union[ndarray, StreamGridCallable], Unit]], | |
+] | |
diff --git a/yt/loaders.py b/yt/loaders.py | |
index c32b31f1c..6e5ece628 100644 | |
--- a/yt/loaders.py | |
+++ b/yt/loaders.py | |
@@ -22,7 +22,7 @@ from yt._maintenance.deprecation import ( | |
future_positional_only, | |
issue_deprecation_warning, | |
) | |
-from yt._typing import AnyFieldKey, AxisOrder, FieldKey | |
+from yt._typing import AnyFieldKey, AxisOrder, FieldKey, UniformGridData | |
from yt.data_objects.static_output import Dataset | |
from yt.funcs import levenshtein_distance | |
from yt.sample_data.api import lookup_on_disk_data | |
@@ -174,7 +174,7 @@ def _sanitize_axis_order_args( | |
def load_uniform_grid( | |
- data: Dict[AnyFieldKey, np.ndarray], | |
+ data: UniformGridData, | |
domain_dimensions, | |
length_unit=None, | |
bbox=None, | |
@@ -293,14 +293,20 @@ def load_uniform_grid( | |
sfh = StreamDictFieldHandler() | |
+ # record the size of any in-memory fields. This simplifies type checking. | |
+ in_mem_fields_shape: Dict[AnyFieldKey, Tuple[int, ...]] = {} | |
+ for ky, val in data.items(): | |
+ if isinstance(val, np.ndarray): | |
+ in_mem_fields_shape[ky] = val.shape | |
+ | |
if number_of_particles > 0: | |
particle_types = set_particle_types(data) | |
# Used much further below. | |
pdata: Dict[Union[str, FieldKey], Any] = { | |
"number_of_particles": number_of_particles | |
} | |
- for key in list(data.keys()): | |
- if len(data[key].shape) == 1 or key[0] == "io": | |
+ for key in in_mem_fields_shape.keys(): | |
+ if len(in_mem_fields_shape[key]) == 1 or key[0] == "io": | |
field: FieldKey | |
if not isinstance(key, tuple): | |
field = ("io", key) | |
@@ -313,15 +319,15 @@ def load_uniform_grid( | |
particle_types = {} | |
if nprocs > 1: | |
- temp = {} | |
- new_data = {} # type: ignore [var-annotated] | |
+ temp: Dict[AnyFieldKey, np.ndarray] = {} | |
+ new_data: Dict[int, Dict[AnyFieldKey, np.ndarray]] = {} | |
for key in data.keys(): | |
- psize = get_psize(np.array(data[key].shape), nprocs) | |
+ psize = get_psize(np.array(in_mem_fields_shape[key]), nprocs) | |
grid_left_edges, grid_right_edges, shapes, slices = decompose_array( | |
- data[key].shape, psize, bbox | |
+ in_mem_fields_shape[key], psize, bbox | |
) | |
grid_dimensions = np.array([shape for shape in shapes], dtype="int32") | |
- temp[key] = [data[key][slice] for slice in slices] | |
+ temp[key] = [data[key][slice] for slice in slices] # type: ignore | |
for gid in range(nprocs): | |
new_data[gid] = {} | |
for key in temp.keys(): |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.