Last active
October 27, 2024 16:33
-
-
Save prerakmody/def4c3be07f6a69e87a0cd3a53616cb4 to your computer and use it in GitHub Desktop.
Scribble generation (in an auto-segmentation editing scenario)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import skimage | |
import skimage.morphology | |
import voxynth as voxynth # https://github.com/dalcalab/voxynth | |
import numpy as np | |
import matplotlib.pyplot as plt | |
def getContourScribbleForBinarySlice(mask, show=False):
    """
    Generate a contour-following scribble for a binary 2D mask.

    The mask is eroded twice; the difference between the two erosions is a
    thin ring lying inside the mask border. That ring is gated with a
    thresholded Perlin-noise mask, warped with a random smooth deformation
    and finally skeletonized.

    Params
    ------
    mask: [H,W], np.ndarray, containing 1s and 0s
    show: bool, if True the intermediate masks are plotted with matplotlib

    Returns
    -------
    contourScribbleMask: [H,W], np.ndarray (uint8), containing 1s and 0s (scribble before warping)
    contourScribbleMaskFinal: [H,W], np.ndarray (uint8), containing 1s and 0s (warped + skeletonized scribble)

    Any step that fails prints its traceback (nothing is re-raised); outputs
    that were not computed yet are returned as None.
    """
    import traceback  # local import: not provided by the module header

    # Both outputs are pre-initialised so the final return is always valid,
    # even when the very first step raises.
    contourScribbleMask = None
    contourScribbleMaskFinal = None

    try:
        # Step 1 - Erode mask (FP/FN)
        # NOTE: You do not want to draw scribbles on the error mask's contour itself.
        #       The distance map will thicken it.
        maskErode = skimage.morphology.erosion(mask.astype(np.uint8))
        maskErode2 = skimage.morphology.erosion(maskErode)

        # Step 2 - Get edge (ring between the two erosions)
        maskErodeEdge = maskErode - maskErode2

        # Step 3 - Get noise mask (i.e. thresholded perlin-noise)
        def sampleNoiseMask():
            noise = voxynth.noise.perlin(
                shape=(maskErodeEdge.shape[0], maskErodeEdge.shape[1]),
                smoothing=np.random.uniform(2, 14),
                magnitude=1,
            )
            return noise.cpu().numpy() > 0

        noiseMask = sampleNoiseMask()

        # Step 4 - Get interaction mask (i.e. noise mask * edge mask)
        contourScribbleMask = noiseMask * maskErodeEdge

        # Step 4.1 - If the noise removed the whole edge, retry once with fresh noise
        if np.sum(contourScribbleMask) == 0:
            contourScribbleMask = sampleNoiseMask() * maskErodeEdge

        # Step 4.2 - If it is still empty, fall back to the full edge
        if np.sum(contourScribbleMask) == 0:
            contourScribbleMask = maskErodeEdge
        contourScribbleMask = contourScribbleMask.astype(np.uint8)

        # Step 5 - Warp (the scribble)
        # Other settings that were tried: smoothing (10,20) / magnitude (1,2) [default],
        # smoothing (4,16) / magnitude (1,6) [ScribblePrompt]. For contour scribbles
        # you can't have that much variation along the contour.
        import torch  # local import: only needed for the warp, not provided by the module header

        deformation_field = voxynth.transform.random_transform(
            shape=mask.shape,
            affine_probability=0,
            warp_probability=1,
            warp_integrations=0,
            warp_smoothing_range=(10, 16),
            warp_magnitude_range=(1, 1),
            isdisp=False,
        )
        contourScribbleMaskWarped = voxynth.transform.spatial_transform(
            torch.from_numpy(contourScribbleMask).unsqueeze(0),
            trf=deformation_field,
            isdisp=False,
        ).cpu().numpy()[0]

        # Step 6 - Thin the warped scribble back to a skeleton
        contourScribbleMaskWarpedSkel = skimage.morphology.skeletonize(contourScribbleMaskWarped > 0)
        contourScribbleMaskFinal = contourScribbleMaskWarpedSkel.astype(np.uint8)

        # Step 99 - Plot
        if show:
            # (background, overlay, title) for each panel
            panels = [
                (mask, None, 'mask'),
                (mask, maskErode, 'maskErode'),
                (mask, maskErodeEdge, 'maskErodeEdge'),
                (noiseMask, maskErodeEdge, 'noiseMask'),
                (mask, contourScribbleMask, 'contourScribbleMask'),
                (mask, contourScribbleMaskFinal, 'contourScribbleMaskFinal'),
            ]
            _, axarr = plt.subplots(1, len(panels))
            for ax, (background, overlay, title) in zip(axarr, panels):
                ax.imshow(background)
                if overlay is not None:
                    ax.imshow(overlay, alpha=0.5)
                ax.set_title(title)
            plt.show()

    except Exception:
        # Best effort: report the problem and return what is available.
        traceback.print_exc()

    return contourScribbleMask, contourScribbleMaskFinal
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
# Values accepted as `scribbleType` when generating morphology-based scribbles.
KEY_SCRIBBLE_MEDIAL_AXIS = "scribble_medial_axis"
KEY_SCRIBBLE_SKELETONIZE = "scribble_skeletonize"
def getMaskForLabel(arrayMask, label):
    """
    Build a binary mask marking where arrayMask equals label.

    Params
    ------
    arrayMask: np.ndarray of labels (e.g. [H,W] or [H,W,Depth])
    label: int, label for which mask is to be generated

    Returns
    -------
    arrayMaskOfLabel: np.ndarray (float64), same shape as arrayMask,
        1 where arrayMask == label and 0 elsewhere.
        None if the mask could not be built (the traceback is printed).
    """
    import traceback  # local import: not provided by the module header

    arrayMaskOfLabel = None
    try:
        # Build into a temporary so a failure never returns a half-filled mask.
        labelMask = np.zeros(arrayMask.shape)
        labelMask[arrayMask == label] = 1
        arrayMaskOfLabel = labelMask
    except Exception:
        # Best effort: report the problem and return None.
        traceback.print_exc()

    return arrayMaskOfLabel
def getMorphologyScribbleForBinarySlice(binaryMaskSlice, scribbleType, valMode=False, verbose=False):
    """
    Applies skimage.morphology.{medial_axis, skeletonize} on the binaryMaskSlice
    and then, at random, degrades the result to emulate an imperfect scribble.

    One of three things happens to the skeleton (roughly one third each):
     - a few pixels are removed and one random connected component is kept
     - only a random 30%-100% subset of its points is kept
     - it is kept as is (always the case when valMode=True)

    Params
    ------
    binaryMaskSlice: [H,W], np.ndarray, containing 1s and 0s
    scribbleType: KEY_SCRIBBLE_MEDIAL_AXIS or KEY_SCRIBBLE_SKELETONIZE
    valMode: bool, if True the skeleton is never broken/sub-sampled
    verbose: bool, if True print extra diagnostics

    Returns
    -------
    binaryInteractionMaskSlice: [H,W], np.ndarray, the full medial-axis/skeleton
    binaryInteractionMaskSliceBroken: [H,W], np.ndarray, the randomised version
        (this one is used for the final distance map)

    Always returns a 2-tuple. Both entries are None for an invalid
    scribbleType or when the morphology step fails (the traceback is printed).
    If only the randomisation fails, the unbroken skeleton is returned for both.
    """
    import random  # local imports: not provided by the module header
    import traceback

    binaryInteractionMaskSlice = None
    binaryInteractionMaskSliceBroken = None

    def randomBreakSkeleton(skeleton, num_breaks=3):
        # Zero out up to num_breaks randomly chosen skeleton pixels (in place).
        coords = np.argwhere(skeleton)
        if len(coords) == 0:
            return skeleton  # nothing to break
        for _ in range(num_breaks):
            # Index with randrange rather than random.choice(ndarray), whose
            # behaviour on ndarrays differs between Python versions.
            y, x = coords[random.randrange(len(coords))]
            skeleton[y, x] = 0
        return skeleton

    try:
        # Step 0 - Make sure the skimage submodules used below are loaded
        import skimage.measure
        import skimage.morphology

        # Step 1 - Get medial axis / skeleton
        if scribbleType == KEY_SCRIBBLE_MEDIAL_AXIS:
            binaryInteractionMaskSlice = skimage.morphology.medial_axis(binaryMaskSlice)
        elif scribbleType == KEY_SCRIBBLE_SKELETONIZE:
            binaryInteractionMaskSlice = skimage.morphology.skeletonize(binaryMaskSlice)
        else:
            print(' --- [INFO][getMorphologyScribbleForBinarySlice()] Invalid scribbleType: {}'.format(scribbleType))
            # Same arity as the normal return so callers can always unpack.
            return binaryInteractionMaskSlice, binaryInteractionMaskSliceBroken

        # Step 2 - Try to break/randomize it
        try:
            chanceVal = random.random()
            if valMode:
                chanceVal = 1.0  # no breakage for validation

            # Step 2.1 - Break the medial-axis/skeleton, find connected components and select a random one
            if chanceVal < 0.33:
                skeletonBroken = randomBreakSkeleton(binaryInteractionMaskSlice.copy())
                components, componentCount = skimage.measure.label(skeletonBroken, return_num=True)
                if componentCount == 0:
                    binaryInteractionMaskSliceBroken = binaryInteractionMaskSlice
                    if verbose:
                        print(' --- [INFO][getMorphologyScribbleForBinarySlice()] No components found in broken skeleton: np.sum(binaryMaskSlice)={}, np.sum(binaryInteractionMaskSlice)={}, '.format(np.sum(binaryMaskSlice), np.sum(binaryInteractionMaskSlice)))
                else:
                    binaryInteractionMaskSliceBroken = getMaskForLabel(components, random.randint(1, componentCount))

            # Step 2.2 - OR keep only x% of the points (to emulate scribble as a collection of points)
            elif chanceVal < 0.66:
                points = np.argwhere(binaryInteractionMaskSlice)
                numKeep = int(np.random.uniform(0.3, 1) * len(points))
                if len(points) > 0:
                    numKeep = max(1, numKeep)  # never reduce a non-empty skeleton to nothing
                # Sample row indices instead of random.shuffle(points): shuffling a
                # 2-D ndarray with the stdlib swaps row views and duplicates rows.
                keepIdxs = np.asarray(random.sample(range(len(points)), numKeep), dtype=np.intp)
                keptPoints = points[keepIdxs]
                # Write into the "broken" output; the full skeleton stays untouched.
                binaryInteractionMaskSliceBroken = np.zeros_like(binaryMaskSlice)
                binaryInteractionMaskSliceBroken[keptPoints[:, 0], keptPoints[:, 1]] = 1

            # Step 2.3 - OR keep the medial-axis/skeleton as is
            else:
                binaryInteractionMaskSliceBroken = binaryInteractionMaskSlice

        except Exception:
            print(' - [getMorphologyScribbleForBinarySlice()] Error in breaking the skeleton')
            # Best effort: fall back to the unbroken skeleton.
            binaryInteractionMaskSliceBroken = binaryInteractionMaskSlice

    except Exception:
        # Best effort: report the problem and return what is available.
        traceback.print_exc()

    return binaryInteractionMaskSlice, binaryInteractionMaskSliceBroken  # broken is used for final distance map
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment