Created
June 14, 2023 08:58
-
-
Save petebankhead/b0d7d50b6eb99372a4fcacc363c7b2f5 to your computer and use it in GitHub Desktop.
QuPath script to threshold pixel classification outputs with fixed (non-default) threshold values
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * QuPath v0.4.3 script to threshold probability outputs from a pixel classifier.
 *
 * By default, converting the pixel classifier output to probabilities will always
 * apply a softmax operation and take the class with the highest probability.
 *
 * This script provides an alternative, whereby you can specify a probability threshold
 * for one or more channels, and threshold to create objects from that.
 * Using a higher threshold can then restrict the object creation to more confident predictions.
 *
 * Achieving this is a bit of a hack, since QuPath's current design doesn't make it very
 * easy to do at all... it's possible in this script, but it accesses (and modifies) a private
 * variable so it might break in future versions.
 *
 * Note that this script also supports selected objects, so you can restrict the object creation
 * to just part of the image. But because it deletes all child objects by default, be careful not
 * to inadvertently delete things you don't want to lose.
 *
 * Written initially for https://forum.image.sc/t/qupath-pixel-classifier-probability-threshold/82190
 *
 * @author Pete Bankhead
 */
import org.locationtech.jts.geom.Geometry | |
import qupath.lib.analysis.images.ContourTracing | |
import qupath.lib.images.servers.ImageServerMetadata | |
import qupath.lib.images.servers.PixelType | |
import qupath.lib.objects.PathObjects | |
import qupath.lib.objects.classes.PathClass | |
import qupath.lib.regions.ImagePlane | |
import qupath.lib.regions.RegionRequest | |
import qupath.lib.roi.GeometryTools | |
import qupath.opencv.ml.pixel.PixelClassifierTools | |
import static qupath.lib.gui.scripting.QPEx.* | |
// ---------------------------------------------------------------------------
// User settings: the classifier to load, the output channels to threshold, and
// the probability (0-1) a pixel must reach in a channel to be included.
// ---------------------------------------------------------------------------
String classifierName = "4-class pixel classifier"
def channelNames = ["Tumor", "Stroma", "Necrosis"]
double probabilityThreshold = 0.8

// Create an ImageServer from the pixel classifier and check it gives probabilities as output
def classifier = loadPixelClassifier(classifierName)
def imageData = getCurrentImageData()
if (imageData == null) {
    // Without image data there is no hierarchy to add objects to
    println "No image data available!"
    return
}
def server = PixelClassifierTools.createPixelClassificationServer(imageData, classifier)
if (server.getMetadata().getChannelType() != ImageServerMetadata.ChannelType.PROBABILITY) {
    println "Channel type is not probability!"
    return
}

// Hack! Setting metadata for a probability image server isn't supported, so we need to change the private field
// Groovy allows this (although Java would try to stop us)
def metadata = new ImageServerMetadata.Builder(server.getMetadata())
        .channelType(ImageServerMetadata.ChannelType.DEFAULT)
        .build()
server.originalMetadata = metadata

// Channel names in channel order; computed once (after the metadata swap) so that each
// requested name can be mapped to its channel index without rebuilding the list per lookup
def serverChannelNames = server.getMetadata().getChannels().collect(c -> c.name)

// Determine thresholds (these depend only on the server, so they are the same for every object)
double minThreshold = probabilityThreshold
if (server.getPixelType() == PixelType.UINT8) {
    minThreshold = probabilityThreshold * 255 // Minimum threshold - probability scaled for 8-bit images
}
double maxThreshold = Double.POSITIVE_INFINITY // No maximum threshold

// Get the selected objects, or default to the root object (to threshold the whole image)
def selectedObjects = getSelectedObjects() as List
if (selectedObjects.isEmpty()) {
    selectedObjects = [imageData.getHierarchy().getRootObject()]
}

try {
    // Loop through the selected objects
    for (def selected : selectedObjects) {
        selected.clearChildObjects() // Remove any existing objects (delete this line if you don't want this!)

        // Identify where we need to threshold: the whole image by default,
        // or only the selected object's ROI (on that ROI's plane) when it has one
        def roi = selected.getROI()
        def plane = ImagePlane.getDefaultPlane()
        Geometry clipMask = null
        RegionRequest request = RegionRequest.createInstance(server)
        if (roi != null) {
            clipMask = roi.getGeometry()
            plane = roi.getImagePlane()
            request = RegionRequest.createInstance(server.getPath(), server.getDownsampleForResolution(0), roi)
        }

        // Loop through the channels to threshold
        for (def name : channelNames) {
            if (Thread.currentThread().isInterrupted()) {
                println "Interrupted!"
                return
            }
            int channel = serverChannelNames.indexOf(name)
            if (channel < 0) {
                println "Channel not found: $name"
                continue
            }
            // Trace contours and add a new annotation
            def geom = ContourTracing.traceGeometry(server, request, clipMask, channel, minThreshold, maxThreshold)
            if (geom != null) {
                def annotation = PathObjects.createAnnotationObject(
                        GeometryTools.geometryToROI(geom, plane),
                        PathClass.fromString(name)
                )
                annotation.setLocked(true) // Remove this if you don't want to lock the new annotation
                selected.addChildObject(annotation)
            }
        }
    }
} finally {
    // Child objects were removed/added directly on the objects rather than through the hierarchy,
    // so request a hierarchy update to notify listeners (this also runs on the early
    // 'Interrupted!' return, when some objects may already have been changed)
    fireHierarchyUpdate()
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment