Last active
August 31, 2023 10:25
-
-
Save lassoan/428af5285da75dc033d32ebff65ba940 to your computer and use it in GitHub Desktop.
Training data augmentation with random volume translation, rotation, and deformation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This script randomly warps a 3D volume (and its label map), adds random translations and rotations,
# and saves each resulting 3D volume (and a screenshot for quick overview)
# | |
# The script can be executed by copy-pasting into 3D Slicer's Python console | |
# or in a Jupyter notebook running 3D Slicer kernel (provided by SlicerJupyter extension). | |
# | |
# Prerequisites: | |
# - Recent Slicer-4.11 version | |
# - SlicerIGT extension installed (for random deformations) | |
import SampleData | |
import ScreenCapture | |
import numpy as np | |
import os | |
#############################
# Set inputs

# Start from an empty scene so nodes left over from earlier runs do not interfere
slicer.mrmlScene.Clear()

# Input volume that will be deformed. To use your own volume instead, load it with:
#   slicer.util.loadVolume("c:/path/to/myvolume.nrrd")
volumeNode = SampleData.SampleDataLogic().downloadMRBrainTumor1()

# Download the segmentation that matches the sample volume and load it.
# NOTE(review): the first element of the download result is passed to loadSegmentation;
# presumably a file path — confirm against SampleData.downloadFromURL.
segmentationDownloadResult = SampleData.downloadFromURL(
    fileNames='MRBrainTumor1.seg.nrrd',
    loadFileTypes=['SegmentationFile'],
    uris='https://github.com/Slicer/SlicerTestingData/releases/download/SHA256/dfd34fe31b48e605a8419efb7427ed42030c90d163b9785fe92692474e664310',
    checksums='SHA256:dfd34fe31b48e605a8419efb7427ed42030c90d163b9785fe92692474e664310')
segmentationNode = slicer.util.loadSegmentation(segmentationDownloadResult[0])

# Flatten all segments into one label map volume (so it can be resampled like an image),
# then remove the segmentation node, which is no longer needed.
labelVolumeNode = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLLabelMapVolumeNode')
segmentationsLogic = slicer.modules.segmentations.logic()
segmentationsLogic.ExportAllSegmentsToLabelmapNode(segmentationNode, labelVolumeNode)
slicer.mrmlScene.RemoveNode(segmentationNode)

# Output parameters: number of augmented samples and where they are written.
# Each filename pattern is filled in with the zero-based sample index.
numberOfOutputVolumesToCreate = 24
outputRootDir = slicer.app.temporaryPath
outputVolumeFilenamePattern = outputRootDir + "/volumes/transformedVolume_%04d.nrrd"
outputLabelVolumeFilenamePattern = outputRootDir + "/volumes/transformedVolume_%04d-label.nrrd"
outputScreenshotsFilenamePattern = outputRootDir + "/screenshots/transformedVolume_%04d.png"

# Transformation parameters: standard deviations of the zero-mean normal distributions
# that the random translation, rotation and control point displacements are drawn from.
translationStDev = 10.0            # translation along each axis
rotationDegStDev = 5.0             # rotation around each axis, in degrees
warpingControlPointsSpacing = 70   # distance between warping grid control points
warpingDisplacementStdDev = 5.0    # displacement of each control point along each axis
#############################
# Processing

# Create the output folders. exist_ok makes this a no-op for folders that already
# exist, without a separate (race-prone) existence check.
for filepath in [outputVolumeFilenamePattern, outputLabelVolumeFilenamePattern, outputScreenshotsFilenamePattern]:
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
# Set up warping transform computation.
# "PointsFrom" receives a fixed grid of control points and "PointsTo" the randomly
# displaced grid (both are filled in by the augmentation loop).
pointsFrom = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", "PointsFrom")
# Ensure a display node exists before configuring it (does nothing if one is already present)
pointsFrom.CreateDefaultDisplayNodes()
pointsFrom.SetLocked(True)
pointsFrom.GetDisplayNode().SetPointLabelsVisibility(False)
pointsFrom.GetDisplayNode().SetSelectedColor(0, 1, 0)
pointsTo = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", "PointsTo")
pointsTo.CreateDefaultDisplayNodes()
pointsTo.GetDisplayNode().SetPointLabelsVisibility(False)

# Bounds of the input volume; the warping control point grid is laid out inside them
volumeBounds = [0, 0, 0, 0, 0, 0]
volumeNode.GetBounds(volumeBounds)

# The warping transform is computed by SlicerIGT's Fiducial Registration Wizard in warping
# mode from the PointsFrom/PointsTo lists and written into warpingTransformNode.
# Without SlicerIGT, warpingTransformNode stays None and the loop applies only
# translation and rotation.
warpingTransformNode = None
if hasattr(slicer.modules, "fiducialregistrationwizard"):
    warpingTransformNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLTransformNode", "WarpingTransform")
    fidReg = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLFiducialRegistrationWizardNode")
    fidReg.SetRegistrationModeToWarping()
    fidReg.SetAndObserveFromFiducialListNodeId(pointsFrom.GetID())
    fidReg.SetAndObserveToFiducialListNodeId(pointsTo.GetID())
    fidReg.SetOutputTransformNodeId(warpingTransformNode.GetID())
else:
    slicer.util.errorDisplay("SlicerIGT extension is required for applying warping transform")
# Linear transform node holding the full per-sample transform. A new vtkMatrix4x4 is an
# identity matrix, so the node starts as identity. The augmentation loop replaces its
# content with warping followed by random translation and rotation.
fullTransformNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLTransformNode", "FullTransform")
identityMatrix = vtk.vtkMatrix4x4()
fullTransformNode.SetAndObserveMatrixTransformToParent(identityMatrix)
volumeNode.SetAndObserveTransformNodeID(fullTransformNode.GetID())

# Output node plus ResampleScalarVectorDWIVolume CLI parameters. The same dictionary is
# reused later with "inputVolume" and "interpolationType" updated for each run.
transformedVolumeNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLScalarVolumeNode")
parameters = {}
parameters["inputVolume"] = volumeNode.GetID()
parameters["outputVolume"] = transformedVolumeNode.GetID()
parameters["referenceVolume"] = volumeNode.GetID()
parameters["transformationFile"] = fullTransformNode.GetID()

# Resample once with the identity transform so transformedVolumeNode already has image
# data when the viewers and volume rendering are configured.
resampleParameterNode = slicer.cli.runSync(slicer.modules.resamplescalarvectordwivolume, None, parameters)
# Visualization for screenshots: four-up layout with the resampled volume in the slice
# views, markups hidden, and a plain black 3D view background.
layoutManager = slicer.app.layoutManager()
layoutManager.setLayout(slicer.vtkMRMLLayoutNode.SlicerLayoutFourUpView)
slicer.util.setSliceViewerLayers(background=transformedVolumeNode, fit=True)
for markupsNode in (pointsFrom, pointsTo):
    markupsNode.GetDisplayNode().SetVisibility(False)
threeDViewNode = layoutManager.threeDWidget(0).mrmlViewNode()
threeDViewNode.SetBackgroundColor(0, 0, 0)
threeDViewNode.SetBackgroundColor2(0, 0, 0)

# Volume rendering of the resampled volume. The preset is chosen by intensity range as a
# heuristic: a range under 1500 is treated as MRI, anything wider as CT.
volRenLogic = slicer.modules.volumerendering.logic()
displayNode = volRenLogic.CreateDefaultVolumeRenderingNodes(transformedVolumeNode)
displayNode.SetVisibility(True)
scalarMin, scalarMax = transformedVolumeNode.GetImageData().GetScalarRange()
presetName = 'MR-Default' if scalarMax - scalarMin < 1500 else 'CT-Chest-Contrast-Enhanced'
displayNode.GetVolumePropertyNode().Copy(volRenLogic.GetPresetByName(presetName))
# Generate as many deformed volumes as requested. For each sample, the label map and the
# intensity volume are resampled through the same random transform and saved, then a
# screenshot of the views is captured.


def _resampleAndSave(inputNode, interpolationType, outputFilename, description, outputIndex, cliNode):
    """Resample inputNode into transformedVolumeNode and save the result to outputFilename.

    Resampling uses the module-level CLI ``parameters`` dict, whose "transformationFile"
    entry points at fullTransformNode.

    :param inputNode: volume node to resample.
    :param interpolationType: CLI interpolation type, e.g. "nn" or "linear".
    :param outputFilename: file the resampled volume is written to.
    :param description: text used in log/error messages, e.g. "label volume".
    :param outputIndex: zero-based index of the current augmented sample.
    :param cliNode: CLI parameter node from the previous run (or None).
    :return: the CLI parameter node, to be passed back in for the next run.
    :raises RuntimeError: if the resampled volume could not be saved.
    """
    parameters["inputVolume"] = inputNode.GetID()
    parameters["interpolationType"] = interpolationType
    cliNode = slicer.cli.runSync(slicer.modules.resamplescalarvectordwivolume, cliNode, parameters)
    print("Save transformed {0} {1}/{2} as {3}".format(
        description, outputIndex + 1, numberOfOutputVolumesToCreate, outputFilename))
    # Previously the save result was ignored, so a failed write went unnoticed
    if not slicer.util.saveNode(transformedVolumeNode, outputFilename):
        raise RuntimeError("Failed to save transformed {0} as {1}".format(description, outputFilename))
    return cliNode


# The control point grid depends only on the volume bounds and the grid spacing,
# so compute it once instead of on every iteration.
if warpingTransformNode:
    controlPointCoordsSplit = np.mgrid[
        volumeBounds[0]:volumeBounds[1]:warpingControlPointsSpacing,
        volumeBounds[2]:volumeBounds[3]:warpingControlPointsSpacing,
        volumeBounds[4]:volumeBounds[5]:warpingControlPointsSpacing]
    # One row per control point: (x, y, z)
    controlPointCoords = np.vstack([axisCoords.ravel() for axisCoords in controlPointCoordsSplit]).T

# A single screen capture logic instance serves all screenshots
cap = ScreenCapture.ScreenCaptureLogic()

for outputVolumeIndex in range(numberOfOutputVolumesToCreate):
    # Build the full transform: warping (when available), then random translation and rotation
    fullTransform = vtk.vtkGeneralTransform()
    if warpingTransformNode:
        # Randomly displace every control point. Presumably the registration wizard node,
        # which observes both point lists, updates warpingTransformNode in response.
        controlPointDisplacement = np.random.normal(0, warpingDisplacementStdDev, size=controlPointCoords.shape)
        slicer.util.updateMarkupsControlPointsFromArray(pointsFrom, controlPointCoords)
        slicer.util.updateMarkupsControlPointsFromArray(pointsTo, controlPointCoords + controlPointDisplacement)
        fullTransform.Concatenate(warpingTransformNode.GetTransformFromParent())
    fullTransform.Translate(np.random.normal(0, translationStDev, 3))
    fullTransform.RotateX(np.random.normal(0, rotationDegStDev))
    fullTransform.RotateY(np.random.normal(0, rotationDegStDev))
    fullTransform.RotateZ(np.random.normal(0, rotationDegStDev))
    fullTransformNode.SetAndObserveTransformFromParent(fullTransform)

    # Label map: nearest neighbor interpolation preserves the label values
    resampleParameterNode = _resampleAndSave(
        labelVolumeNode, "nn", outputLabelVolumeFilenamePattern % outputVolumeIndex,
        "label volume", outputVolumeIndex, resampleParameterNode)
    # Intensity volume: linear interpolation. It stays in transformedVolumeNode for the screenshot.
    resampleParameterNode = _resampleAndSave(
        volumeNode, "linear", outputVolumeFilenamePattern % outputVolumeIndex,
        "volume", outputVolumeIndex, resampleParameterNode)

    # Save a screenshot of all views with the view controllers hidden
    cap.showViewControllers(False)
    cap.captureImageFromView(None, outputScreenshotsFilenamePattern % outputVolumeIndex)
    cap.showViewControllers(True)
# Create a gallery (lightbox) image of all augmented screenshots. This block uses its own
# ScreenCaptureLogic instead of a variable defined inside the loop, and is skipped when
# no screenshots were produced.
if numberOfOutputVolumesToCreate > 0:
    screenshotsDir = os.path.dirname(outputScreenshotsFilenamePattern)
    ScreenCapture.ScreenCaptureLogic().createLightboxImage(
        8,  # number of columns
        screenshotsDir,
        os.path.basename(outputScreenshotsFilenamePattern),
        numberOfOutputVolumesToCreate,
        os.path.join(screenshotsDir, "gallery.png"))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment