Skip to content

Instantly share code, notes, and snippets.

@lassoan
Last active August 31, 2023 10:25
Show Gist options
  • Save lassoan/428af5285da75dc033d32ebff65ba940 to your computer and use it in GitHub Desktop.
Save lassoan/428af5285da75dc033d32ebff65ba940 to your computer and use it in GitHub Desktop.
Training data augmentation with random volume translation, rotation, and deformation
# This script randomly warps a 3D volume, applies random translations and rotations,
# and saves each resulting 3D volume (and a screenshot for quick overview).
#
# The script can be executed by copy-pasting into 3D Slicer's Python console
# or in a Jupyter notebook running 3D Slicer kernel (provided by SlicerJupyter extension).
#
# Prerequisites:
# - Recent Slicer-4.11 version
# - SlicerIGT extension installed (for random deformations)
import SampleData
import ScreenCapture
import numpy as np
import os
#############################
# Set inputs
# Start from an empty scene so nodes from a previous run do not interfere.
slicer.mrmlScene.Clear()

# Input volume that will be deformed. To use your own volume instead, load it from file:
#   volumeNode = slicer.util.loadVolume("c:/path/to/myvolume.nrrd")
volumeNode = SampleData.SampleDataLogic().downloadMRBrainTumor1()

# Download the sample segmentation file that accompanies the sample volume;
# the first entry returned by downloadFromURL is loaded as a segmentation node.
downloadedSegmentationFile = SampleData.downloadFromURL(
    fileNames='MRBrainTumor1.seg.nrrd',
    loadFileTypes=['SegmentationFile'],
    uris='https://github.com/Slicer/SlicerTestingData/releases/download/SHA256/dfd34fe31b48e605a8419efb7427ed42030c90d163b9785fe92692474e664310',
    checksums='SHA256:dfd34fe31b48e605a8419efb7427ed42030c90d163b9785fe92692474e664310')[0]
segmentationNode = slicer.util.loadSegmentation(downloadedSegmentationFile)

# The labels are resampled like a volume later on, so export all segments into a
# labelmap volume node and discard the segmentation node afterwards.
labelVolumeNode = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLLabelMapVolumeNode')
segmentationsLogic = slicer.modules.segmentations.logic()
segmentationsLogic.ExportAllSegmentsToLabelmapNode(segmentationNode, labelVolumeNode)
slicer.mrmlScene.RemoveNode(segmentationNode)
# Set output parameters.
# Number of augmented samples to generate. Each file name pattern is filled in with the
# zero-based sample index (zero-padded to 4 digits).
numberOfOutputVolumesToCreate = 24
outputBaseDir = slicer.app.temporaryPath
outputVolumeFilenamePattern = outputBaseDir + "/volumes/transformedVolume_%04d.nrrd"
outputLabelVolumeFilenamePattern = outputBaseDir + "/volumes/transformedVolume_%04d-label.nrrd"
outputScreenshotsFilenamePattern = outputBaseDir + "/screenshots/transformedVolume_%04d.png"

# Transformation parameters (standard deviations of the normally distributed random values)
translationStDev = 10.0           # random translation, drawn independently for each of the 3 axes
rotationDegStDev = 5.0            # random rotation about each of X, Y and Z, in degrees
warpingControlPointsSpacing = 70  # step of the warping control point grid across the volume bounds
warpingDisplacementStdDev = 5.0   # random displacement of each control point coordinate
#############################
# Processing
# Create output folders.
# os.makedirs(..., exist_ok=True) creates missing parents and tolerates folders that already
# exist (e.g. from a previous run, or two patterns sharing a folder) without a separate
# exists() check, which would be racy. A pattern without a directory part means the
# current directory, which needs no creation.
for filepath in [outputVolumeFilenamePattern, outputLabelVolumeFilenamePattern, outputScreenshotsFilenamePattern]:
    filedir = os.path.dirname(filepath)
    if filedir:
        os.makedirs(filedir, exist_ok=True)
# Set up warping transform computation
pointsFrom = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", "PointsFrom")
pointsFrom.SetLocked(True)
pointsFrom.GetDisplayNode().SetPointLabelsVisibility(False)
pointsFrom.GetDisplayNode().SetSelectedColor(0,1,0)
pointsTo = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", "PointsTo")
pointsTo.GetDisplayNode().SetPointLabelsVisibility(False)
volumeBounds=[0,0,0,0,0,0]
volumeNode.GetBounds(volumeBounds)
warpingTransformNode = None
if hasattr(slicer.modules, "fiducialregistrationwizard"):
warpingTransformNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLTransformNode", "WarpingTransform")
fidReg = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLFiducialRegistrationWizardNode")
fidReg.SetRegistrationModeToWarping()
fidReg.SetAndObserveFromFiducialListNodeId(pointsFrom.GetID())
fidReg.SetAndObserveToFiducialListNodeId(pointsTo.GetID())
fidReg.SetOutputTransformNodeId(warpingTransformNode.GetID())
else:
slicer.util.errorDisplay("SlicerIGT extension is required for applying warping transform")
# Set up linear transform computation.
# fullTransformNode holds the complete per-sample transform that is rebuilt for every
# augmented sample; it starts out as identity. The input volume is placed under it.
# NOTE(review): fullTransformNode is also handed to the resampling CLI as "transformationFile";
# confirm the CLI does not additionally apply the input volume's parent transform.
identityMatrix = vtk.vtkMatrix4x4()
fullTransformNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLTransformNode", "FullTransform")
fullTransformNode.SetAndObserveMatrixTransformToParent(identityMatrix)
volumeNode.SetAndObserveTransformNodeID(fullTransformNode.GetID())
# Set up transformation chain: volume is warped, then translated&rotated.
# The resampling CLI writes its result into transformedVolumeNode, on the geometry of the
# input volume (referenceVolume), using fullTransformNode as the transform.
transformedVolumeNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLScalarVolumeNode")
parameters = dict(
    inputVolume=volumeNode.GetID(),
    outputVolume=transformedVolumeNode.GetID(),
    referenceVolume=volumeNode.GetID(),
    transformationFile=fullTransformNode.GetID(),
)
# Resample once while fullTransformNode is still identity, so transformedVolumeNode has image
# data for the visualization setup; the returned CLI node is reused for later runs.
resampleParameterNode = slicer.cli.runSync(slicer.modules.resamplescalarvectordwivolume, None, parameters)
# Set up visualization for screenshots: four-up layout showing the resampled volume,
# warping point lists hidden, black 3D view background.
layoutManager = slicer.app.layoutManager()
layoutManager.setLayout(slicer.vtkMRMLLayoutNode.SlicerLayoutFourUpView)
slicer.util.setSliceViewerLayers(background=transformedVolumeNode, fit=True)
for markupsNode in (pointsFrom, pointsTo):
    markupsNode.GetDisplayNode().SetVisibility(False)
threeDViewNode = layoutManager.threeDWidget(0).mrmlViewNode()
threeDViewNode.SetBackgroundColor(0,0,0)
threeDViewNode.SetBackgroundColor2(0,0,0)

# Volume rendering of the resampled volume. The transfer function preset is picked from the
# intensity range: a range narrower than 1500 probably means MRI, a wider one probably CT.
volRenLogic = slicer.modules.volumerendering.logic()
displayNode = volRenLogic.CreateDefaultVolumeRenderingNodes(transformedVolumeNode)
displayNode.SetVisibility(True)
scalarMin, scalarMax = transformedVolumeNode.GetImageData().GetScalarRange()
presetName = 'MR-Default' if scalarMax - scalarMin < 1500 else 'CT-Chest-Contrast-Enhanced'
displayNode.GetVolumePropertyNode().Copy(volRenLogic.GetPresetByName(presetName))
# Generate as many deformed volumes as requested.

# Loop-invariant setup. The warping control points lie on a fixed regular grid over the
# original volume bounds, so the grid (and the "from" point list) is set only once.
if warpingTransformNode:
    controlPointCoordsSplit = np.mgrid[
        volumeBounds[0]:volumeBounds[1]:warpingControlPointsSpacing,
        volumeBounds[2]:volumeBounds[3]:warpingControlPointsSpacing,
        volumeBounds[4]:volumeBounds[5]:warpingControlPointsSpacing]
    # One row per control point: an (N, 3) array of x, y, z coordinates.
    controlPointCoords = np.vstack([axisCoords.ravel() for axisCoords in controlPointCoordsSplit]).T
    slicer.util.updateMarkupsControlPointsFromArray(pointsFrom, controlPointCoords)
cap = ScreenCapture.ScreenCaptureLogic()


def _resampleAndSave(inputVolumeNode, interpolationType, outputFilename, description, outputVolumeIndex, parameterNode):
    """Resample one input volume with the current fullTransformNode and save the result.

    Runs the resampling CLI (reusing parameterNode) with the shared ``parameters`` dict so
    the result lands in transformedVolumeNode, then saves transformedVolumeNode.

    :param inputVolumeNode: volume node to resample.
    :param interpolationType: interpolation mode passed to the CLI ("nn" or "linear" here).
    :param outputFilename: file the resampled volume is written to.
    :param description: name used in the progress message, e.g. "label volume".
    :param outputVolumeIndex: zero-based index of the current augmented sample.
    :param parameterNode: CLI node returned by the previous run.
    :returns: the CLI node returned by this run, to be passed to the next call.
    :raises RuntimeError: if saving the resampled volume fails.
    """
    parameters["inputVolume"] = inputVolumeNode.GetID()
    parameters["interpolationType"] = interpolationType
    parameterNode = slicer.cli.runSync(slicer.modules.resamplescalarvectordwivolume, parameterNode, parameters)
    print("Save transformed {0} {1}/{2} as {3}".format(description, outputVolumeIndex+1, numberOfOutputVolumesToCreate, outputFilename))
    # saveNode reports failure through its return value; do not let a missing output go unnoticed.
    if not slicer.util.saveNode(transformedVolumeNode, outputFilename):
        raise RuntimeError("Failed to save transformed {0} as {1}".format(description, outputFilename))
    return parameterNode


for outputVolumeIndex in range(numberOfOutputVolumesToCreate):
    # Build this sample's transform: optional random warp, then random translation and rotation.
    # The effective composition order follows vtkGeneralTransform's concatenation mode
    # (pre-multiply by default).
    fullTransform = vtk.vtkGeneralTransform()
    if warpingTransformNode:
        # Randomly displace every control point. warpingTransformNode is expected to be updated
        # by the registration wizard that observes pointsTo.
        controlPointDisplacement = np.random.normal(0, warpingDisplacementStdDev, size=controlPointCoords.shape)
        slicer.util.updateMarkupsControlPointsFromArray(pointsTo, controlPointCoords + controlPointDisplacement)
        fullTransform.Concatenate(warpingTransformNode.GetTransformFromParent())
    fullTransform.Translate(np.random.normal(0, translationStDev, 3))
    fullTransform.RotateX(np.random.normal(0, rotationDegStDev))
    fullTransform.RotateY(np.random.normal(0, rotationDegStDev))
    fullTransform.RotateZ(np.random.normal(0, rotationDegStDev))
    fullTransformNode.SetAndObserveTransformFromParent(fullTransform)

    # Label volume: nearest neighbor interpolation to preserve label values.
    resampleParameterNode = _resampleAndSave(
        labelVolumeNode, "nn", outputLabelVolumeFilenamePattern % outputVolumeIndex,
        "label volume", outputVolumeIndex, resampleParameterNode)
    # Intensity volume: linear interpolation. Resampled last, so transformedVolumeNode holds
    # the intensity image when the screenshot is taken.
    resampleParameterNode = _resampleAndSave(
        volumeNode, "linear", outputVolumeFilenamePattern % outputVolumeIndex,
        "volume", outputVolumeIndex, resampleParameterNode)

    # Save result screenshot (view=None, i.e. not a single specific view). View controllers are
    # hidden for the capture and restored even if the capture raises.
    cap.showViewControllers(False)
    try:
        cap.captureImageFromView(None, outputScreenshotsFilenamePattern % outputVolumeIndex)
    finally:
        cap.showViewControllers(True)
# Create gallery view of all augmented images from the saved screenshots.
# Uses its own ScreenCaptureLogic rather than a variable left over from the loop (which
# would not exist if no samples were generated), and is skipped when there is nothing to show.
if numberOfOutputVolumesToCreate > 0:
    screenshotsDir = os.path.dirname(outputScreenshotsFilenamePattern)
    # First argument: number of columns in the lightbox image.
    ScreenCapture.ScreenCaptureLogic().createLightboxImage(
        8,
        screenshotsDir,
        os.path.basename(outputScreenshotsFilenamePattern),
        numberOfOutputVolumesToCreate,
        os.path.join(screenshotsDir, "gallery.png"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment