Created September 17, 2024 19:12
-
-
Save tfmoraes/175b4bfff96c497e2a3838a7ecaabbe7 to your computer and use it in GitHub Desktop.
gen_patches.patch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/invesalius/segmentation/deep_learning/segment.py b/invesalius/segmentation/deep_learning/segment.py | |
index 041e6f90..dc17a203 100644 | |
--- a/invesalius/segmentation/deep_learning/segment.py | |
+++ b/invesalius/segmentation/deep_learning/segment.py | |
@@ -5,8 +5,13 @@ import pathlib | |
import sys | |
import tempfile | |
import traceback | |
+from typing import Generator, Tuple | |
+import monai | |
import numpy as np | |
+from monai.data import decollate_batch | |
+from monai.inferers import sliding_window_inference | |
+from monai.transforms import Activations, Compose | |
from skimage.transform import resize | |
from vtkmodules.vtkIOXML import vtkXMLImageDataWriter | |
@@ -21,23 +26,31 @@ from . import utils | |
SIZE = 48 | |
- | |
-import monai | |
-from monai.data import decollate_batch | |
-from monai.inferers import sliding_window_inference | |
-from monai.transforms import * | |
+patch_type = Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]] | |
-def gen_patches(image, patch_size, overlap): | |
+def gen_patches( | |
+ image: np.ndarray, patch_size: int, overlap: int | |
+) -> Generator[Tuple[float, np.ndarray, patch_type], None, None]: | |
overlap = int(patch_size * overlap / 100) | |
sz, sy, sx = image.shape | |
- i_cuts = list( | |
- itertools.product( | |
- range(0, sz, patch_size - overlap), | |
- range(0, sy, patch_size - overlap), | |
- range(0, sx, patch_size - overlap), | |
- ) | |
- ) | |
+ slices_x = [i for i in range(0, sx, patch_size - overlap) if i + patch_size <= sx] | |
+ if not slices_x: | |
+ slices_x.append(0) | |
+ elif slices_x[-1] + patch_size < sx: | |
+ slices_x.append(sx - patch_size) | |
+ slices_y = [i for i in range(0, sy, patch_size - overlap) if i + patch_size <= sy] | |
+ if not slices_y: | |
+ slices_y.append(0) | |
+ elif slices_y[-1] + patch_size < sy: | |
+ slices_y.append(sy - patch_size) | |
+ slices_z = [i for i in range(0, sz, patch_size - overlap) if i + patch_size <= sz] | |
+ if not slices_z: | |
+ slices_z.append(0) | |
+ elif slices_z[-1] + patch_size < sz: | |
+ slices_z.append(sz - patch_size) | |
+ i_cuts = list(itertools.product(slices_z, slices_y, slices_x)) | |
+ | |
sub_image = np.empty(shape=(patch_size, patch_size, patch_size), dtype="float32") | |
for idx, (iz, iy, ix) in enumerate(i_cuts): | |
sub_image[:] = 0 |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.