Skip to content

Instantly share code, notes, and snippets.

@prerakmody
Last active October 27, 2024 16:33
Show Gist options
  • Save prerakmody/def4c3be07f6a69e87a0cd3a53616cb4 to your computer and use it in GitHub Desktop.
Save prerakmody/def4c3be07f6a69e87a0cd3a53616cb4 to your computer and use it in GitHub Desktop.
Scribble generation (in an auto-segmentation editing scenario)
import pdb
import random
import traceback

import matplotlib.pyplot as plt
import numpy as np
import skimage
import skimage.measure
import skimage.morphology
import torch
import voxynth as voxynth # https://github.com/dalcalab/voxynth
def _perlinNoiseMask(shape):
    """Return a random boolean [H,W] mask: voxynth Perlin noise (smoothing ~ U[2,14)) thresholded at 0."""
    noise = voxynth.noise.perlin(shape=(shape[0], shape[1]), smoothing=np.random.uniform(2, 14), magnitude=1)
    return noise.cpu().numpy() > 0


def getContourScribbleForBinarySlice(mask, show=False):
    """
    Simulate a scribble that follows (fragments of) the contour of a 2D binary mask.

    The mask is eroded twice. The thin ring between the two erosions (an edge lying just inside
    the mask boundary) is intersected with a random Perlin-noise mask, so that only pieces of the
    contour survive. Those pieces are then slightly warped with a random deformation field and
    skeletonized.

    Params
    ------
    mask: [H,W], np.ndarray, containing 1s and 0s
    show: bool, if True, plot the intermediate masks with matplotlib

    Returns
    -------
    contourScribbleMask: [H,W], np.ndarray (uint8), the un-warped scribble, containing 1s and 0s
    contourScribbleMaskFinal: [H,W], np.ndarray (uint8), the warped + skeletonized scribble, containing 1s and 0s
    Either value is None if an exception occurred before it was computed (the traceback is printed).
    """
    contourScribbleMask = None
    contourScribbleMaskFinal = None
    try:
        # Step 1 - Erode mask (FP/FN). NOTE: scribbles should not be drawn on the error mask's
        # contour itself; a later distance map will thicken them.
        maskErode = skimage.morphology.erosion(mask.astype(np.uint8))
        maskErode2 = skimage.morphology.erosion(maskErode)

        # Step 2 - Edge = ring between the two erosions (maskErode2 is contained in maskErode,
        # so this uint8 subtraction cannot underflow)
        maskErodeEdge = maskErode - maskErode2

        # Step 3 - Random binary noise mask
        noiseMask = _perlinNoiseMask(maskErodeEdge.shape)

        # Step 4 - Interaction mask = noise mask * edge mask
        contourScribbleMask = noiseMask * maskErodeEdge
        # Step 4.1 - If nothing survived, retry once with fresh noise, then fall back to the full edge
        if np.sum(contourScribbleMask) == 0:
            contourScribbleMask = _perlinNoiseMask(maskErodeEdge.shape) * maskErodeEdge
        if np.sum(contourScribbleMask) == 0:
            contourScribbleMask = maskErodeEdge
        contourScribbleMask = np.round(contourScribbleMask).astype(np.uint8)

        # Step 5 - Warp the scribble. The warp is milder than voxynth's defaults (smoothing 10-20,
        # magnitude 1-2) and ScribblePrompt's (smoothing 4-16, magnitude 1-6): along a contour,
        # large deformations would pull the scribble away from the boundary.
        deformation_field = voxynth.transform.random_transform(
            shape=mask.shape,
            affine_probability=0,
            warp_probability=1,
            warp_integrations=0,
            warp_smoothing_range=(10, 16),
            warp_magnitude_range=(1, 1),
            isdisp=False,
        )
        contourScribbleMaskWarped = voxynth.transform.spatial_transform(
            torch.from_numpy(contourScribbleMask).unsqueeze(0), trf=deformation_field, isdisp=False
        ).cpu().numpy()[0]

        # Step 6 - Thin the warped strokes back to one-pixel-wide lines
        contourScribbleMaskWarpedSkel = skimage.morphology.skeletonize(contourScribbleMaskWarped > 0)
        contourScribbleMaskFinal = contourScribbleMaskWarpedSkel.astype(np.uint8)

        # Step 99 - Plot (no debugger breakpoint: it would block every call with show=True)
        if show:
            f, axarr = plt.subplots(1, 6)
            axarr[0].imshow(mask); axarr[0].set_title('mask')
            axarr[1].imshow(mask); axarr[1].imshow(maskErode, alpha=0.5); axarr[1].set_title('maskErode')
            axarr[2].imshow(mask); axarr[2].imshow(maskErodeEdge, alpha=0.5); axarr[2].set_title('maskErodeEdge')
            axarr[3].imshow(noiseMask); axarr[3].imshow(maskErodeEdge, alpha=0.5); axarr[3].set_title('noiseMask')
            axarr[4].imshow(mask); axarr[4].imshow(contourScribbleMask, alpha=0.5); axarr[4].set_title('contourScribbleMask')
            axarr[5].imshow(mask); axarr[5].imshow(contourScribbleMaskFinal, alpha=0.5); axarr[5].set_title('contourScribbleMaskFinal')
            plt.show()
    except Exception:
        # Best effort: report and return whatever was computed, instead of dropping into pdb
        # (which hangs non-interactive training/data-loading jobs).
        traceback.print_exc()
    return contourScribbleMask, contourScribbleMaskFinal
import numpy as np
# Values accepted by getMorphologyScribbleForBinarySlice(scribbleType=...)
KEY_SCRIBBLE_MEDIAL_AXIS: str = 'scribble_medial_axis'  # -> skimage.morphology.medial_axis
KEY_SCRIBBLE_SKELETONIZE: str = 'scribble_skeletonize'  # -> skimage.morphology.skeletonize
def getMaskForLabel(arrayMask, label):
    """
    Build a binary mask of the entries of a label map that equal `label`.

    Params
    ------
    arrayMask: np.ndarray (any shape, e.g. [H,W] or [H,W,Depth]), integer label map,
               e.g. the connected-component output of skimage.measure.label
    label: int, label for which mask is to be generated

    Returns
    -------
    arrayMaskOfLabel: np.ndarray (float64), same shape as arrayMask,
                      1.0 where arrayMask == label and 0.0 elsewhere.
                      None if the comparison fails (the traceback is printed).
    """
    arrayMaskOfLabel = None
    try:
        # Vectorized form of "zeros, then set 1 where equal"; float64 matches np.zeros' default dtype.
        arrayMaskOfLabel = (np.asarray(arrayMask) == label).astype(np.float64)
    except Exception:
        # Report instead of dropping into pdb (which blocks non-interactive runs).
        traceback.print_exc()
    return arrayMaskOfLabel
def getMorphologyScribbleForBinarySlice(binaryMaskSlice, scribbleType, valMode=False, verbose=False):
    """
    Simulate a scribble inside a 2D binary mask from its medial axis / skeleton.

    Applies skimage.morphology.{medial_axis, skeletonize} on binaryMaskSlice. Then, unless
    valMode is set, one of three perturbations is chosen at random (each with probability ~1/3):
      - break the skeleton at a few random pixels and keep one random connected piece,
      - keep only a random 30-100% subset of its pixels (a scribble as a collection of points),
      - keep it as is.

    Params
    ------
    binaryMaskSlice: [H,W], np.ndarray, containing 1s and 0s
    scribbleType: str, KEY_SCRIBBLE_MEDIAL_AXIS or KEY_SCRIBBLE_SKELETONIZE
    valMode: bool, if True, never perturb (the skeleton is returned as is for validation)
    verbose: bool, print extra diagnostics

    Returns
    -------
    binaryInteractionMaskSlice: [H,W], np.ndarray, containing 1s and 0s (the full medial axis / skeleton)
    binaryInteractionMaskSliceBroken: [H,W], np.ndarray, containing 1s and 0s (the perturbed
        scribble, same dtype as binaryInteractionMaskSlice; used for the final distance map)
    Always a 2-tuple; entries that could not be computed (invalid scribbleType, unexpected error)
    are None.
    """
    binaryInteractionMaskSlice = None
    binaryInteractionMaskSliceBroken = None

    def _randomBreakSkeleton(skeleton, num_breaks=3):
        # Zero out `num_breaks` randomly chosen skeleton pixels (drawn with replacement), in place.
        coords = np.column_stack(np.where(skeleton))
        if len(coords) == 0:  # empty skeleton: nothing to break (random.choice would raise)
            return skeleton
        for _ in range(num_breaks):
            y, x = random.choice(coords)
            skeleton[y, x] = 0
        return skeleton

    try:
        # Step 1 - Get medial axis / skeleton
        if scribbleType == KEY_SCRIBBLE_MEDIAL_AXIS:
            binaryInteractionMaskSlice = skimage.morphology.medial_axis(binaryMaskSlice)
        elif scribbleType == KEY_SCRIBBLE_SKELETONIZE:
            binaryInteractionMaskSlice = skimage.morphology.skeletonize(binaryMaskSlice)
        else:
            # format() (not '+') so a non-str scribbleType cannot raise TypeError here; return a
            # 2-tuple like every other path so callers can always unpack the result.
            print(' --- [INFO][getMorphologyScribbleForBinarySlice()] Invalid scribbleType: {}'.format(scribbleType))
            return binaryInteractionMaskSlice, binaryInteractionMaskSliceBroken

        # Step 2 - Try to break/randomize it (best effort: on failure keep the intact skeleton)
        try:
            chanceVal = random.random()
            if valMode:
                chanceVal = 1.0  # no breakage for validation

            if chanceVal < 0.33:
                # Step 2.1 - Break the skeleton, find connected components and keep a random one
                skeletonBroken = _randomBreakSkeleton(binaryInteractionMaskSlice.copy())
                components, componentCount = skimage.measure.label(skeletonBroken, return_num=True)
                if componentCount == 0:
                    binaryInteractionMaskSliceBroken = binaryInteractionMaskSlice
                    if verbose:
                        print(' --- [INFO][getMorphologyScribbleForBinarySlice()] No components found in broken skeleton: np.sum(binaryMaskSlice)={}, np.sum(binaryInteractionMaskSlice)={}, '.format(np.sum(binaryMaskSlice), np.sum(binaryInteractionMaskSlice)))
                else:
                    chosenLabel = random.randint(1, componentCount)
                    binaryInteractionMaskSliceBroken = (components == chosenLabel).astype(binaryInteractionMaskSlice.dtype)

            elif chanceVal < 0.66:
                # Step 2.2 - Keep only x% of the points (to emulate a scribble as a collection of points).
                # The result goes into the *broken* mask; the full skeleton is returned untouched.
                points = np.argwhere(binaryInteractionMaskSlice)
                # np.random.permutation rather than random.shuffle: random.shuffle swaps rows of a
                # 2D numpy array through views, which duplicates some points and loses others.
                points = points[np.random.permutation(len(points))]
                keepCount = max(1, int(np.random.uniform(0.3, 1) * len(points)))  # never drop every point
                points = points[:keepCount]
                binaryInteractionMaskSliceBroken = np.zeros_like(binaryInteractionMaskSlice)
                binaryInteractionMaskSliceBroken[points[:, 0], points[:, 1]] = 1

            else:
                # Step 2.3 - Keep the medial axis / skeleton as is
                binaryInteractionMaskSliceBroken = binaryInteractionMaskSlice
        except Exception:
            print(' - [getMorphologyScribbleForBinarySlice()] Error in breaking the skeleton')
            if verbose:
                traceback.print_exc()
            binaryInteractionMaskSliceBroken = binaryInteractionMaskSlice

    except Exception:
        # Report instead of dropping into pdb (which blocks non-interactive runs).
        traceback.print_exc()

    return binaryInteractionMaskSlice, binaryInteractionMaskSliceBroken  # broken is used for final distance map
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment