Stochastically Varying Spatial Smoothing (SVLS)
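In brief: SVLS softens one-hot voxel labels by convolving each class channel with a 3D Gaussian whose centre weight is pinned to the total weight of its neighbours, so a voxel keeps half of its label mass and borrows the rest from its spatial context. A sketch of the construction implemented below (notation mine, not from the original code): given a Gaussian kernel g normalized so that \sum g = 1, with centre value g_c,

    w_c = 1, \qquad w_{ijk} = \frac{g_{ijk}}{1 - g_c} \;(\text{off-centre}), \qquad \sum_{ijk} w_{ijk} = 2, \qquad \tilde{y} = \frac{y * w}{\sum_{ijk} w_{ijk}}

where * denotes 3D convolution applied per class channel.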
import math
import pdb
import traceback

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

import config  # project-level settings module providing _EPSILON (and JIT_COMPILE)

def get_svls_filter_3d(kernel_size=3, sigma=1, verbose=False):
    """
    Builds a fixed (non-trainable) Conv3D layer holding the 3D SVLS smoothing kernel.
    Ref: https://github.com/mobarakol/SVLS/blob/main/svls.py (pytorch)
     - Alternative (for the Gaussian kernel): https://gist.github.com/blzq/c87d42f45a8c5a53f5b393e27b1f5319
    Note: the groups parameter of Conv3D raises an issue in tf==2.10.0 on Unix:
     - "tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:453] ptxas returned an error during compilation of ptx to sass: 'INTERNAL: Failed to launch ptxas'"
    """
    try:
        # Step 1 - Create an x, y, z coordinate grid of shape (kernel_size, kernel_size, kernel_size, 3)
        x_coord = tf.range(kernel_size)                                       # [3]
        x_grid_2d = tf.tile(x_coord, [kernel_size])                           # [3*3=9]
        x_grid_2d = tf.reshape(x_grid_2d, (kernel_size, kernel_size))         # [3,3]
        x_grid = tf.tile(x_grid_2d, [kernel_size, 1])                         # [3,3] -> [3*3,3]
        x_grid = tf.reshape(x_grid, (kernel_size, kernel_size, kernel_size))  # [3,3,3]
        y_grid_2d = tf.transpose(x_grid_2d)                                   # [3,3]
        y_grid = tf.tile(y_grid_2d, [kernel_size, 1])                         # [3,3] -> [3*3,3]
        y_grid = tf.reshape(y_grid, (kernel_size, kernel_size, kernel_size))  # [3,3,3]
        z_grid = tf.tile(y_grid_2d, [1, kernel_size])
        z_grid = tf.reshape(z_grid, (kernel_size, kernel_size, kernel_size))
        xyz_grid = tf.stack([x_grid, y_grid, z_grid], axis=-1)
        xyz_grid = tf.cast(xyz_grid, tf.float32)

        # Step 2 - Calculate the 3-dimensional Gaussian kernel
        mean = (kernel_size - 1) / 2.
        variance = sigma**2.
        gaussian_kernel = (1. / (2. * math.pi * variance + 1e-16)) * tf.exp(-tf.reduce_sum((xyz_grid - mean)**2., axis=-1) / (2 * variance + 1e-16))

        # Step 3 - Normalize, then set the centre weight to the neighbour mass, so the final kernel sums to 2
        gaussian_kernel = gaussian_kernel / tf.reduce_sum(gaussian_kernel)  # now sums to 1
        neighbors_sum = 1 - gaussian_kernel[1, 1, 1]  # centre index [1,1,1] assumes kernel_size=3
        # Step 3.1 - Use tf.scatter_nd, since tensors do not support item assignment
        indices = tf.constant([[1, 1, 1]])
        updates = tf.constant([1], dtype=gaussian_kernel.dtype)
        shape = tf.constant([kernel_size, kernel_size, kernel_size])
        tensor_neighbours_sum = tf.scatter_nd(indices, updates, shape)  # [3,3,3] tensor with a 1 at [1,1,1] and zeros elsewhere
        gaussian_kernel = tensor_neighbours_sum - gaussian_kernel  # centre = neighbors_sum, neighbours = -g_ijk
        svls_kernel_3d = tf.abs(gaussian_kernel / neighbors_sum)  # centre = 1, neighbours sum to 1
        if verbose: print(' - [losses2.py][get_svls_filter_3d()] svls_kernel_3d: ', svls_kernel_3d)

        # Step 4 - Make a convolutional layer and set its weights (applied on each label separately, since the groups parameter gives a runtime error)
        svls_kernel_3d = tf.reshape(svls_kernel_3d, (kernel_size, kernel_size, kernel_size, 1, 1))
        svls_filter_3d = tf.keras.layers.Conv3D(filters=1, kernel_size=kernel_size, use_bias=False, padding='same')
        svls_filter_3d.build((None, None, None, None, 1))
        svls_filter_3d.set_weights([svls_kernel_3d])
        svls_filter_3d.trainable = False
        return svls_filter_3d, svls_kernel_3d[:, :, :, :, 0]  # layer, kernel of shape [3,3,3,1]
    except Exception:
        traceback.print_exc()
        # pdb.set_trace()
        return None, None
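A minimal usage sketch of the filter factory on a random binary volume (the variable names and shapes below are illustrative, not from the gist):

# Usage sketch (illustrative names/shapes; assumes eager TF2). The kernel sums to 2,
# so dividing the convolved volume by that sum recovers a proper weighted average.
svls_layer_demo, svls_kernel_demo = get_svls_filter_3d(kernel_size=3, sigma=1)
print(float(tf.reduce_sum(svls_kernel_demo)))  # ~2.0 (centre weight 1.0 + neighbour mass 1.0)
vol = tf.cast(tf.random.uniform((1, 16, 16, 16, 1)) > 0.5, tf.float32)  # binary [B,H,W,D,1] mask
vol_smoothed = svls_layer_demo(vol) / tf.reduce_sum(svls_kernel_demo)  # 'same' padding keeps the shape
print(vol_smoothed.shape)  # (1, 16, 16, 16, 1)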
class CELossWithSVLS(tf.keras.Model):

    def __init__(self, classes, sigma):
        super(CELossWithSVLS, self).__init__()
        print(' - [losses2.py][CELossWithSVLS] Using sigma=', sigma, ' for classes=', classes)
        self.cls = tf.constant(classes)
        self.sigma = sigma
        self.svls_layer, self.svls_kernel = get_svls_filter_3d(sigma=sigma)  # self.svls_kernel = [3,3,3,1]

    # @tf.function(jit_compile=config.JIT_COMPILE)
    def call(self, y_true, y_pred, label_mask, weights):
        try:
            # Step 0 - Init
            label_mask = tf.cast(label_mask, dtype=tf.float32)

            # Step 0.1 - Label (y_true) smoothing: convolve each class channel with the SVLS kernel
            labelCount = tf.cast(tf.math.reduce_sum(tf.ones_like(weights)), tf.int32)
            def process_classID(classID):
                return self.svls_layer(tf.expand_dims(y_true[:, :, :, :, classID], axis=-1))
            yTrueSmoothed = tf.map_fn(process_classID, tf.range(labelCount), fn_output_signature=tf.float32)  # [B,H,W,D,L] --> [L,B,H,W,D,1]
            yTrueSmoothed = tf.transpose(yTrueSmoothed, perm=[1, 2, 3, 4, 0, 5])[:, :, :, :, :, 0]  # [L,B,H,W,D,1] --> [B,H,W,D,L]
            # Loop-based equivalent (handy in pdb): yTrueSmoothed = tf.concat([self.svls_layer(tf.expand_dims(y_true[:,:,:,:,classID], axis=-1)) for classID in range(labelCount)], axis=-1)
            yTrueSmoothed = yTrueSmoothed / tf.math.reduce_sum(self.svls_kernel)  # the kernel sums to 2, so renormalize

            # Step 0.2 - Debug visualization (eager mode only; pauses in pdb). Set to 1 to inspect the smoothing.
            if 0:
                print(' - [losses.py][CELossWithSVLS][batch=0] labels with GT: ', np.sum(y_true[0, :, :, :, :], axis=(0, 1, 2)))
                print(' - [losses.py][CELossWithSVLS][batch=1] labels with GT: ', np.sum(y_true[1, :, :, :, :], axis=(0, 1, 2)))
                batchId = np.random.choice([0, 1])
                labelId = np.random.choice(np.argwhere(np.sum(y_true[batchId, :, :, :, :], axis=(0, 1, 2))).flatten())
                sliceId = np.random.choice(np.argwhere(np.sum(y_true[batchId, :, :, :, labelId], axis=(0, 1))).flatten())
                cmap = 'Oranges'
                f, axarr = plt.subplots(3, 2)
                plt.suptitle(' - [losses.py][CELossWithSVLS] batchId: ' + str(batchId) + ' || labelId: ' + str(labelId) + ' || sliceId: ' + str(sliceId) + ' || sigma: ' + str(self.sigma))
                axarr[0, 0].imshow(y_true[batchId, :, :, sliceId - 1, labelId], vmin=0, vmax=1, cmap=cmap); axarr[0, 0].set_title('sliceId-1: ' + str(sliceId - 1))
                axarr[1, 0].imshow(y_true[batchId, :, :, sliceId, labelId], vmin=0, vmax=1, cmap=cmap); axarr[1, 0].set_title('y_true | unique: ' + str(np.unique(y_true[batchId, :, :, sliceId, labelId])))
                axarr[2, 0].imshow(y_true[batchId, :, :, sliceId + 1, labelId], vmin=0, vmax=1, cmap=cmap); axarr[2, 0].set_title('sliceId+1: ' + str(sliceId + 1))
                yTrueSmoothedUniqueVals = ['{:.3f}'.format(each) for each in np.unique(yTrueSmoothed[batchId, :, :, sliceId, labelId])][:4]
                axarr[0, 1].imshow(yTrueSmoothed[batchId, :, :, sliceId - 1, labelId], vmin=0, vmax=1, cmap=cmap); axarr[0, 1].set_title('sliceId-1: ' + str(sliceId - 1))
                axarr[1, 1].imshow(yTrueSmoothed[batchId, :, :, sliceId, labelId], vmin=0, vmax=1, cmap=cmap); axarr[1, 1].set_title('y_true_smoothed | unique: ' + str(yTrueSmoothedUniqueVals) + ' ... ')
                axarr[2, 1].imshow(yTrueSmoothed[batchId, :, :, sliceId + 1, labelId], vmin=0, vmax=1, cmap=cmap); axarr[2, 1].set_title('sliceId+1: ' + str(sliceId + 1))
                plt.show(block=False)
                pdb.set_trace()

            # Step 1.1 - Foreground loss
            loss_labels_pos = -1.0 * yTrueSmoothed * tf.math.log(y_pred + config._EPSILON)  # [B,H,W,D,L]
            loss_labels_pos = label_mask * tf.math.reduce_sum(loss_labels_pos, axis=[1, 2, 3])  # [B,H,W,D,L] --> [B,L]
            # Step 1.2 - Background loss
            loss_labels_neg = -1.0 * (1 - yTrueSmoothed) * tf.math.log(1 - y_pred + config._EPSILON)  # [B,H,W,D,L]
            loss_labels_neg = label_mask * tf.math.reduce_sum(loss_labels_neg, axis=[1, 2, 3])  # [B,H,W,D,L] --> [B,L]
            loss_labels = loss_labels_pos + loss_labels_neg  # [B,L]

            # Step 2 - Mask results on the basis of ground-truth availability
            label_mask = tf.where(tf.math.greater(label_mask, 0), label_mask, config._EPSILON)  # avoid division by zero below
            loss_labels_for_report = tf.math.reduce_sum(loss_labels, axis=0) / tf.math.reduce_sum(label_mask, axis=0)  # [B,L] -> [L] (per-label average across the batch)
            loss_for_report = tf.math.reduce_mean(tf.math.reduce_sum(loss_labels, axis=1) / tf.math.reduce_sum(label_mask, axis=1))  # [B,L] -> [B] -> scalar (batch mean of per-sample label sums)

            # Step 3 - Weighted cross-entropy (per-label weights)
            if len(weights):
                label_weights = weights / tf.math.reduce_sum(weights)  # normalized
                loss_labels_w = loss_labels * label_weights  # [B,L]
                loss_labels_for_train = tf.math.reduce_sum(loss_labels_w, axis=0) / tf.math.reduce_sum(label_mask, axis=0)  # [L]
                loss_for_train = tf.math.reduce_mean(tf.math.reduce_sum(loss_labels_w, axis=1) / tf.math.reduce_sum(label_mask, axis=1))  # scalar
            else:
                loss_labels_for_train = loss_labels_for_report
                loss_for_train = loss_for_report

            # Step 4 - Return results
            return loss_for_train, loss_labels_for_train, loss_for_report, loss_labels_for_report
        except Exception:
            traceback.print_exc()
            # pdb.set_trace()
            return None, None, None, None
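An end-to-end sketch of calling the loss on dummy tensors (all shapes, names, and the uniform weights below are assumptions; config._EPSILON comes from the project's own settings module):

# End-to-end sketch with dummy data (illustrative; B=2 matches the debug block's batch prints)
B, H, W, D, L = 2, 16, 16, 16, 3
y_true = tf.one_hot(tf.random.uniform((B, H, W, D), 0, L, dtype=tf.int32), depth=L)  # one-hot GT, [B,H,W,D,L]
y_pred = tf.nn.softmax(tf.random.normal((B, H, W, D, L)), axis=-1)  # mock softmax predictions
label_mask = tf.ones((B, L))  # 1 where a label is annotated for that sample, else 0
weights = tf.ones(L)  # per-label weights; their count also sets labelCount in call()
loss_obj = CELossWithSVLS(classes=L, sigma=1)
loss_train, loss_labels_train, loss_report, loss_labels_report = loss_obj(y_true, y_pred, label_mask, weights)
print(float(loss_train), loss_labels_train.numpy())  # scalar training loss and per-label losses [L]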
# Demo - visualize the three z-slices of the 3x3x3 SVLS kernel for different sigmas
import seaborn as sns

_, tmp1 = get_svls_filter_3d(sigma=1)
_, tmp2 = get_svls_filter_3d(sigma=2)
_, tmp3 = get_svls_filter_3d(sigma=3)
_, tmp4 = get_svls_filter_3d(sigma=4)

f, axarr = plt.subplots(2, 6)
vmin, vmax, cmap = 0, 0.1, 'Oranges'
plt.suptitle('3D SVLS (stochastically varying label smoothing) filter')
sns.heatmap(tmp1[0, :, :, 0], ax=axarr[0, 0], vmin=vmin, vmax=vmax, cmap=cmap)
sns.heatmap(tmp1[1, :, :, 0], ax=axarr[0, 1], vmin=vmin, vmax=vmax, cmap=cmap); axarr[0, 1].set_title('sigma=1 (smooth gradient from center to edge)')
sns.heatmap(tmp1[2, :, :, 0], ax=axarr[0, 2], vmin=vmin, vmax=vmax, cmap=cmap)
sns.heatmap(tmp2[0, :, :, 0], ax=axarr[0, 3], vmin=vmin, vmax=vmax, cmap=cmap)
sns.heatmap(tmp2[1, :, :, 0], ax=axarr[0, 4], vmin=vmin, vmax=vmax, cmap=cmap); axarr[0, 4].set_title('sigma=2')
sns.heatmap(tmp2[2, :, :, 0], ax=axarr[0, 5], vmin=vmin, vmax=vmax, cmap=cmap)
sns.heatmap(tmp3[0, :, :, 0], ax=axarr[1, 0], vmin=vmin, vmax=vmax, cmap=cmap)
sns.heatmap(tmp3[1, :, :, 0], ax=axarr[1, 1], vmin=vmin, vmax=vmax, cmap=cmap); axarr[1, 1].set_title('sigma=3')
sns.heatmap(tmp3[2, :, :, 0], ax=axarr[1, 2], vmin=vmin, vmax=vmax, cmap=cmap)
sns.heatmap(tmp4[0, :, :, 0], ax=axarr[1, 3], vmin=vmin, vmax=vmax, cmap=cmap)
sns.heatmap(tmp4[1, :, :, 0], ax=axarr[1, 4], vmin=vmin, vmax=vmax, cmap=cmap); axarr[1, 4].set_title('sigma=4 (sharp gradient from center to edge)')
sns.heatmap(tmp4[2, :, :, 0], ax=axarr[1, 5], vmin=vmin, vmax=vmax, cmap=cmap)
for ax in axarr.flatten(): _ = ax.set_xticks([]); _ = ax.set_yticks([])
plt.show(block=False)