Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save petebankhead/b0d7d50b6eb99372a4fcacc363c7b2f5 to your computer and use it in GitHub Desktop.
Save petebankhead/b0d7d50b6eb99372a4fcacc363c7b2f5 to your computer and use it in GitHub Desktop.
QuPath script to threshold pixel classification outputs with fixed (non-default) threshold values
/**
* QuPath v0.4.3 script to threshold probability outputs from a pixel classifier.
*
* By default, converting the pixel classifier output to probabilities will always
* apply a softmax operation and take the class with the highest probability.
*
* This script provides an alternative, whereby you can specify a probability threshold
* for one or more channels, and threshold to create objects from that.
* Using a higher threshold can then restrict object creation to more confident predictions.
*
* Achieving this is a bit of a hack, since QuPath's current design doesn't make it very
* easy to do at all... it's possible in this script, but it accesses (and modifies) a private
* variable so it might break in future versions.
*
* Note that this script also supports selected objects, so you can restrict the object creation
* to just part of the image. But because it deletes all child objects by default, be careful not
* to inadvertently delete things you don't want to lose.
*
* Written initially for https://forum.image.sc/t/qupath-pixel-classifier-probability-threshold/82190
*
* @author Pete Bankhead
*/
import org.locationtech.jts.geom.Geometry
import qupath.lib.analysis.images.ContourTracing
import qupath.lib.images.servers.ImageServerMetadata
import qupath.lib.images.servers.PixelType
import qupath.lib.objects.PathObjects
import qupath.lib.objects.classes.PathClass
import qupath.lib.regions.ImagePlane
import qupath.lib.regions.RegionRequest
import qupath.lib.roi.GeometryTools
import qupath.opencv.ml.pixel.PixelClassifierTools
import static qupath.lib.gui.scripting.QPEx.*
// Define classifier name & required channels to threshold
String classifierName = "4-class pixel classifier"
def channelNames = ["Tumor", "Stroma", "Necrosis"]
double probabilityThreshold = 0.8

// Create an ImageServer from the pixel classifier and check it gives probabilities as output
def classifier = loadPixelClassifier(classifierName)
def imageData = getCurrentImageData()
def server = PixelClassifierTools.createPixelClassificationServer(imageData, classifier)
if (server.getMetadata().getChannelType() != ImageServerMetadata.ChannelType.PROBABILITY) {
    println "Channel type is not probability!"
    return
}

// Hack! Setting metadata for a probability image server isn't supported, so we need to change the private field
// Groovy allows this (although Java would try to stop us)
def metadata = new ImageServerMetadata.Builder(server.getMetadata())
        .channelType(ImageServerMetadata.ChannelType.DEFAULT)
        .build()
server.originalMetadata = metadata

// Resolve the requested channel names to server channel indices once, up front.
// A Groovy map literal is a LinkedHashMap, so the requested channel order is preserved.
def serverChannelNames = server.getMetadata().getChannels().collect { it.name }
Map<String, Integer> channelIndices = [:]
for (def name : channelNames) {
    int channel = serverChannelNames.indexOf(name)
    if (channel < 0)
        println "Channel not found: $name"
    else
        channelIndices[name] = channel
}
// Bail out before clearing anything: otherwise a typo in channelNames would delete
// existing child objects (or the whole hierarchy) without creating replacements
if (channelIndices.isEmpty()) {
    println "No channels to threshold!"
    return
}

// Determine thresholds (these depend only on the server, so compute them once)
double minThreshold = probabilityThreshold
if (server.getPixelType() == PixelType.UINT8)
    minThreshold = probabilityThreshold * 255 // Minimum threshold - probability scaled for 8-bit images
double maxThreshold = Double.POSITIVE_INFINITY // No maximum threshold

// Get the selected objects, or default to the root object (to threshold the whole image)
def hierarchy = imageData.getHierarchy()
def selectedObjects = getSelectedObjects() as List
if (selectedObjects.isEmpty())
    selectedObjects = [hierarchy.getRootObject()]

try {
    // Loop through the selected objects
    for (def selected : selectedObjects) {
        selected.clearChildObjects() // Remove any existing objects (delete this line if you don't want this!)

        // Identify where we need to threshold: the whole image for the root object,
        // otherwise the selected object's ROI (also used as a clip mask)
        def roi = selected.getROI()
        def plane = ImagePlane.getDefaultPlane()
        Geometry clipMask = null
        RegionRequest request = RegionRequest.createInstance(server)
        if (roi != null) {
            clipMask = roi.getGeometry()
            plane = roi.getImagePlane()
            request = RegionRequest.createInstance(server.getPath(), server.getDownsampleForResolution(0), roi)
        }

        // Loop through the channels to threshold
        for (def entry : channelIndices.entrySet()) {
            // Use isInterrupted() explicitly: the static Thread.interrupted() would clear the flag
            if (Thread.currentThread().isInterrupted()) {
                println "Interrupted!"
                return
            }
            String name = entry.key
            int channel = entry.value

            // Trace contours and add a new annotation
            def geom = ContourTracing.traceGeometry(server, request, clipMask, channel, minThreshold, maxThreshold)
            if (geom != null) {
                def annotation = PathObjects.createAnnotationObject(
                        GeometryTools.geometryToROI(geom, plane),
                        PathClass.fromString(name)
                )
                annotation.setLocked(true) // Remove this if you don't want to lock the new annotation
                selected.addChildObject(annotation)
            }
        }
    }
} finally {
    // Objects were modified directly on PathObjects, so notify the hierarchy
    // (also on early return/interrupt, since some objects may already have changed)
    hierarchy.fireHierarchyChangedEvent(this)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment