Created April 6, 2019 16:11
Save elistevens/237ce6ce5e5c289a5050b91da8fcd582 to your computer and use it in GitHub Desktop.
3D data augmentation from Deep Learning with PyTorch (untested)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _centerCropSlices(shape, width_mm, maxWidth_mm):
    """Return index slices selecting the centered width_mm/maxWidth_mm fraction of the first three axes.

    Each crop length is rounded up to a whole voxel. When the leftover voxel
    count on an axis is odd, the extra voxel is trimmed from the high end.
    """
    slice_list = []
    for axis_len in shape[:3]:
        crop_size = int(math.ceil(axis_len * width_mm / maxWidth_mm))
        start_ndx = (axis_len - crop_size) // 2
        slice_list.append(slice(start_ndx, start_ndx + crop_size))
    # Multi-axis slicing needs a tuple; NumPy >= 1.23 rejects indexing with a
    # list of slices.
    return tuple(slice_list)


def _buildAugmentationTransform(augmentation_dict):
    """Return a random 4x4 float64 homogeneous transform driven by augmentation_dict.

    Recognized keys (others are ignored):
      'scale'  -- float s; each axis gets an independent factor drawn
                  uniformly from [1 - s/2, 1 + s/2).
      'mirror' -- any value; each axis is independently negated with
                  probability ~0.5.
      'rotate' -- any value; a rotation by a uniform angle in [0, 2*pi) about
                  a uniformly random axis is right-multiplied onto the
                  scale/mirror matrix.
    """
    transform_tensor = torch.eye(4, dtype=torch.float64)

    # Scale and Mirror: both only modify the diagonal of the 3x3 linear part.
    for i in range(3):
        if 'scale' in augmentation_dict:
            scale_float = augmentation_dict['scale']
            transform_tensor[i, i] *= 1.0 - scale_float / 2.0 + random.random() * scale_float
        if 'mirror' in augmentation_dict and random.random() > 0.5:
            transform_tensor[i, i] *= -1

    # Rotate
    if 'rotate' in augmentation_dict:
        angle_rad = random.random() * math.pi * 2
        s = math.sin(angle_rad)
        c = math.cos(angle_rad)
        c1 = 1 - c

        # Normalizing a standard-normal sample gives a direction uniformly
        # distributed on the sphere. torch.rand() only produces non-negative
        # components, which restricted the axis to a single octant (and, via
        # the full angle range, its antipode).
        axis_tensor = torch.randn(3, dtype=torch.float64)
        axis_tensor /= axis_tensor.norm()
        # Plain Python floats so the matrix below is built from scalars.
        z, y, x = axis_tensor.tolist()

        # Axis-angle (Rodrigues) rotation matrix in homogeneous 4x4 form.
        rotation_tensor = torch.tensor([
            [x*x*c1 + c,   y*x*c1 - z*s, z*x*c1 + y*s, 0],
            [x*y*c1 + z*s, y*y*c1 + c,   z*y*c1 - x*s, 0],
            [x*z*c1 - y*s, y*z*c1 + x*s, z*z*c1 + c,   0],
            [0,            0,            0,            1],
        ], dtype=torch.float64)

        transform_tensor = transform_tensor @ rotation_tensor

    return transform_tensor


def getCtAugmentedNodule(
    augmentation_dict,
    series_uid,
    center_xyz,
    width_mm,
    voxels_int,
    maxWidth_mm=32.0,
    use_cache=True,
):
    """Return a randomly augmented, resampled cube of CT data and its center.

    A cubic chunk spanning maxWidth_mm is fetched. When use_cache is true it
    comes from getCtCubicChunk; otherwise it comes from
    getCt(series_uid).getCubicInputChunk. The central width_mm/maxWidth_mm
    fraction of that chunk is cropped. The crop is then resampled onto a
    voxels_int**3 grid through a random affine transform.

    Args:
        augmentation_dict: Enables augmentations by key presence:
            'scale' (float, per-axis scale jitter range), 'mirror',
            'rotate', and 'noise' (float multiplier on standard-normal
            additive noise).
        series_uid: Passed through to the chunk loaders; presumably the CT
            series identifier -- TODO confirm.
        center_xyz: Passed through to the chunk loaders; presumably the
            nodule center in patient XYZ coordinates -- TODO confirm.
        width_mm: Width of the region to keep; must satisfy
            0 < width_mm <= maxWidth_mm.
        voxels_int: Edge length, in voxels, of the returned cube.
        maxWidth_mm: Width requested from the chunk loaders.
        use_cache: Use getCtCubicChunk (True) or getCt(...) (False).

    Returns:
        (chunk, center_irc): chunk is a float32 tensor of shape
        (voxels_int, voxels_int, voxels_int); center_irc is whatever the
        chunk loader returned alongside the data.

    Raises:
        ValueError: if width_mm is not in (0, maxWidth_mm].
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # Also rejects non-positive widths, which would produce an empty crop.
    if not 0 < width_mm <= maxWidth_mm:
        raise ValueError(
            f"width_mm must be in (0, maxWidth_mm={maxWidth_mm}], got {width_mm}"
        )

    if use_cache:
        cubic_chunk, center_irc = getCtCubicChunk(series_uid, center_xyz, maxWidth_mm)
    else:
        ct = getCt(series_uid)
        # Bind the same name as the cached branch; previously this assigned
        # `ct_chunk`, so use_cache=False raised NameError below.
        cubic_chunk, center_irc = ct.getCubicInputChunk(center_xyz, maxWidth_mm)

    cropped_chunk = cubic_chunk[_centerCropSlices(cubic_chunk.shape, width_mm, maxWidth_mm)]

    # NOTE(review): an earlier draft rescaled the cached chunk by
    # clamp_value/255 here, suggesting the cache may store compressed values;
    # confirm whether a value rescale is needed. The float32 cast is required
    # regardless: grid_sample needs input and grid (float32 below) to share a dtype.
    cropped_tensor = torch.tensor(cropped_chunk, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    transform_tensor = _buildAugmentationTransform(augmentation_dict)

    # affine_grid takes the top three rows (a 3x4 theta). It maps each output
    # voxel position through that transform to the normalized input location
    # to sample.
    affine_tensor = torch.nn.functional.affine_grid(
        transform_tensor[:3].unsqueeze(0).to(torch.float32),
        torch.Size([1, 1, voxels_int, voxels_int, voxels_int]),
        align_corners=False,
    )
    zoomed_chunk = torch.nn.functional.grid_sample(
        cropped_tensor,
        affine_tensor,
        padding_mode='border',  # out-of-bounds samples reuse border values
        align_corners=False,
    ).to('cpu')

    # Noise
    if 'noise' in augmentation_dict:
        zoomed_chunk += torch.randn_like(zoomed_chunk) * augmentation_dict['noise']

    return zoomed_chunk[0, 0], center_irc
# end::cache[]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.