Skip to content

Instantly share code, notes, and snippets.

@tfmoraes
Created September 17, 2024 19:12
Show Gist options
  • Save tfmoraes/175b4bfff96c497e2a3838a7ecaabbe7 to your computer and use it in GitHub Desktop.
gen_patches.patch
diff --git a/invesalius/segmentation/deep_learning/segment.py b/invesalius/segmentation/deep_learning/segment.py
index 041e6f90..dc17a203 100644
--- a/invesalius/segmentation/deep_learning/segment.py
+++ b/invesalius/segmentation/deep_learning/segment.py
@@ -5,8 +5,13 @@ import pathlib
import sys
import tempfile
import traceback
+from typing import Generator, Tuple
+import monai
import numpy as np
+from monai.data import decollate_batch
+from monai.inferers import sliding_window_inference
+from monai.transforms import Activations, Compose
from skimage.transform import resize
from vtkmodules.vtkIOXML import vtkXMLImageDataWriter
@@ -21,23 +26,31 @@ from . import utils
SIZE = 48
-
-import monai
-from monai.data import decollate_batch
-from monai.inferers import sliding_window_inference
-from monai.transforms import *
+patch_type = Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]
-def gen_patches(image, patch_size, overlap):
+def gen_patches(
+ image: np.ndarray, patch_size: int, overlap: int
+) -> Generator[Tuple[float, np.ndarray, patch_type], None, None]:
overlap = int(patch_size * overlap / 100)
sz, sy, sx = image.shape
- i_cuts = list(
- itertools.product(
- range(0, sz, patch_size - overlap),
- range(0, sy, patch_size - overlap),
- range(0, sx, patch_size - overlap),
- )
- )
+ slices_x = [i for i in range(0, sx, patch_size - overlap) if i + patch_size <= sx]
+ if not slices_x:
+ slices_x.append(0)
+ elif slices_x[-1] + patch_size < sx:
+ slices_x.append(sx - patch_size)
+ slices_y = [i for i in range(0, sy, patch_size - overlap) if i + patch_size <= sy]
+ if not slices_y:
+ slices_y.append(0)
+ elif slices_y[-1] + patch_size < sy:
+ slices_y.append(sy - patch_size)
+ slices_z = [i for i in range(0, sz, patch_size - overlap) if i + patch_size <= sz]
+ if not slices_z:
+ slices_z.append(0)
+ elif slices_z[-1] + patch_size < sz:
+ slices_z.append(sz - patch_size)
+ i_cuts = list(itertools.product(slices_z, slices_y, slices_x))
+
sub_image = np.empty(shape=(patch_size, patch_size, patch_size), dtype="float32")
for idx, (iz, iy, ix) in enumerate(i_cuts):
sub_image[:] = 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.