Last active
September 1, 2024 10:15
-
-
Save dacr/8ab38d35575ff6ad8aa9f962f6bf9b87 to your computer and use it in GitHub Desktop.
image classification compare models / published by https://github.com/dacr/code-examples-manager #f1c02983-8800-4c08-812e-14b732be895d/38dde024e768488c8dade3c67442607aee56b94
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// summary : image classification compare models | |
// keywords : djl, machine-learning, tutorial, detection, ai, @testable | |
// publish : gist | |
// authors : David Crosson | |
// license : Apache NON-AI License Version 2.0 (https://raw.githubusercontent.com/non-ai-licenses/non-ai-licenses/main/NON-AI-APACHE2) | |
// id : f1c02983-8800-4c08-812e-14b732be895d | |
// created-on : 2024-01-28T13:24:06+01:00 | |
// managed-by : https://github.com/dacr/code-examples-manager | |
// run-with : scala-cli $file | |
// --------------------- | |
//> using scala "3.4.2" | |
//> using dep "org.slf4j:slf4j-api:2.0.13" | |
//> using dep "org.slf4j:slf4j-simple:2.0.13" | |
//> using dep "net.java.dev.jna:jna:5.14.0" | |
//> using dep "ai.djl:api:0.29.0" | |
//> using dep "ai.djl:basicdataset:0.29.0" | |
//> using dep "ai.djl:model-zoo:0.29.0" | |
//> using dep "ai.djl.huggingface:tokenizers:0.29.0" | |
//> using dep "ai.djl.mxnet:mxnet-engine:0.29.0" | |
//> using dep "ai.djl.mxnet:mxnet-model-zoo:0.29.0" | |
//> using dep "ai.djl.pytorch:pytorch-engine:0.29.0" | |
//> using dep "ai.djl.pytorch:pytorch-model-zoo:0.29.0" | |
//> using dep "ai.djl.tensorflow:tensorflow-engine:0.29.0" | |
//> using dep "ai.djl.tensorflow:tensorflow-model-zoo:0.29.0" | |
//> using dep "ai.djl.onnxruntime:onnxruntime-engine:0.29.0" | |
// --------------------- | |
// Set the slf4j-simple "defaultLogLevel" property to "error" before any other statement of the script runs.
System.setProperty("org.slf4j.simpleLogger.defaultLogLevel", "error") | |
import ai.djl.Application | |
import ai.djl.engine.Engine | |
import ai.djl.repository.Artifact | |
import ai.djl.repository.zoo.{Criteria, ModelNotFoundException, ModelZoo, ModelZooResolver, ZooModel} | |
import ai.djl.training.util.ProgressBar | |
import ai.djl.modality.Classifications | |
import ai.djl.modality.Classifications.Classification | |
import ai.djl.modality.cv.Image | |
import ai.djl.modality.cv.ImageFactory | |
import java.net.{URI, URL} | |
import java.nio.file.Files | |
import java.nio.file.Path | |
import java.nio.file.Paths | |
import java.util.UUID | |
import java.util.concurrent.TimeUnit | |
import scala.concurrent.duration.Duration | |
import scala.jdk.CollectionConverters.* | |
import scala.io.AnsiColor.{BLUE, BOLD, CYAN, GREEN, MAGENTA, RED, RESET, UNDERLINED, YELLOW} | |
/** Thin wrapper around a DJL [[Artifact]] exposing its Maven-like coordinates.
  *
  * @param artifact the wrapped model zoo artifact
  */
case class ModelArtifact(artifact: Artifact) {

  /** Group identifier read from the artifact metadata. */
  def groupId: String = artifact.getMetadata.getGroupId

  /** Artifact identifier read from the artifact metadata. */
  def artifactId: String = artifact.getMetadata.getArtifactId

  /** Version of the wrapped artifact. */
  def version: String = artifact.getVersion

  /** Artifact properties converted to an immutable Scala map. */
  def properties: Map[String, String] = artifact.getProperties.asScala.toMap

  /** Name-based UUID derived from the coordinates followed by the sorted property list. */
  val uuid: UUID = {
    val sortedProperties = properties.toList.sorted
    val fingerprint      = s"$groupId$artifactId$version$sortedProperties"
    UUID.nameUUIDFromBytes(fingerprint.getBytes)
  }

  /** Same text as [[toString]]: `groupId:artifactId:version`. */
  def ident: String = this.toString

  override def toString: String = s"$groupId:$artifactId:$version"
}
// ---------------------------------------------------------------------------------------------- | |
/** One class label reported for an image together with its probability.
  *
  * @param classification the class name
  * @param probability    the probability attached to that class name
  */
case class ImageClassification(classification: String, probability: Double)
/** Outcome of running one model on one input image.
  *
  * @param inputImageSource     URL the input image was loaded from
  * @param modelArtifact        model zoo artifact that was requested
  * @param selectedModelPath    path reported by the loaded model
  * @param responseTime         time spent in the prediction call
  * @param imageClassifications classes kept for this image
  */
case class ModelResult(
  inputImageSource: URL,
  modelArtifact: ModelArtifact,
  selectedModelPath: Path,
  responseTime: Duration,
  imageClassifications: List[ImageClassification]
)
// Models excluded from the comparison, identified as `groupId:artifactId:version`.
// The failure observed for each one is recorded above its entry.
val blackListed: Set[String] = Set(
  // ai.djl.translate.TranslateException: java.io.FileNotFoundException: File not found: /home/dcr/.djl.ai/cache/repo/model/cv/image_classification/ai/djl/pytorch/resnet18_embedding/0.0.1/synset.txt
  "ai.djl.pytorch:resnet18_embedding:0.0.1",
  // ai.djl.translate.TranslateException: ai.djl.engine.EngineException: MXNet engine call failed: MXNetError: Check failed: src.Size() % known_dim_size_prod == 0 (672 vs. 0) : Cannot reshape array of size 2056320 into shape [-1,784]
  "ai.djl.zoo:mlp:0.0.3"
)
/** Loads one image-classification model from the model zoo and runs it on every input image.
  *
  * The model is selected by group id, artifact id and artifact properties. The loaded model and
  * its predictor are both closed before this method returns, whether or not a prediction fails.
  *
  * @param modelArtifact     the model zoo artifact to test
  * @param inputImageSources URLs of the images to classify
  * @return one [[ModelResult]] per input image, or `Nil` when no model matches the criteria
  */
def testModel(modelArtifact: ModelArtifact, inputImageSources: List[URL]): List[ModelResult] = {
  import scala.util.Using

  println(s"${RED}TESTING MODEL $modelArtifact$RESET")
  val criteria =
    Criteria
      .builder()
      .setTypes(classOf[Image], classOf[Classifications])
      .optApplication(Application.CV.IMAGE_CLASSIFICATION)
      .optGroupId(modelArtifact.groupId)
      .optArtifactId(modelArtifact.artifactId)
      .optFilters(modelArtifact.properties.asJava)
      .optProgress(new ProgressBar)
      .build()
  try {
    // The model and the predictor are AutoCloseable: release them once all images are processed
    // (the predictor first, then the model) instead of leaking them for every tested model.
    Using.resource(ModelZoo.loadModel(criteria)) { model =>
      Using.resource(model.newPredictor()) { predictor =>
        inputImageSources.map { inputImageSource =>
          val inputImage = ImageFactory.getInstance().fromUrl(inputImageSource)
          // nanoTime is monotonic, so the elapsed time is not affected by wall-clock adjustments
          val started                   = System.nanoTime()
          val detected: Classifications = predictor.predict(inputImage)
          val elapsedMillis             = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - started)
          val duration                  = Duration(elapsedMillis, TimeUnit.MILLISECONDS)
          val imageClassifications = detected
            .items[Classification]()
            .asScala
            .toList
            .filter(_.getProbability > 0.01) // mandatory as many classes are returned with 0.00 probability
            .map(item => ImageClassification(item.getClassName, item.getProbability))
          ModelResult(
            inputImageSource = inputImageSource,
            modelArtifact = modelArtifact,
            selectedModelPath = model.getModelPath,
            responseTime = duration,
            imageClassifications = imageClassifications
          )
        }
      }
    }
  } catch {
    case err: ModelNotFoundException =>
      println(s"No matching model for $modelArtifact : ${err.getMessage}")
      Nil
  }
}
/** Prints the classification results to standard output, grouped by input image.
  *
  * Images are listed in the order they first appear in `results`; for each image, every model
  * result is printed with its classes sorted by decreasing probability.
  *
  * @param results the model results to display
  */
def showResults(results: Seq[ModelResult]): Unit = {
  // Group on the textual form of the URL: java.net.URL equals/hashCode may perform blocking
  // host name resolution, which a plain String key avoids.
  val resultsByImage = results.groupBy(_.inputImageSource.toExternalForm)
  // A Map has no guaranteed iteration order, so walk the images in first-appearance order.
  val imagesInOrder = results.map(_.inputImageSource.toExternalForm).distinct
  imagesInOrder.foreach { imageURL =>
    val resultsForImage = resultsByImage.getOrElse(imageURL, Seq.empty)
    println(s"${BLUE}${BOLD}==========================================================================$RESET")
    println(s"${BLUE}${BOLD}RESULTS FOR $imageURL$RESET")
    resultsForImage.foreach { result =>
      import result.*
      println(s"${BLUE}${BOLD}--------------------------------------------------------------------------$RESET")
      println(s"${BLUE}${BOLD}MODEL ${modelArtifact.ident}$RESET")
      println(s"${BLUE}PATH $selectedModelPath$RESET")
      println(s"${GREEN}Number of detected image classes : ${imageClassifications.size} in $responseTime$RESET")
      imageClassifications.sortBy(-_.probability).foreach { detected =>
        println(f" $YELLOW$BOLD${detected.classification} ${RED} ${detected.probability}%1.2f$RESET")
      }
    }
  }
}
// URLs of the 16 sample images, numbered example-001.jpg to example-016.jpg.
val inputImageSources: List[URL] =
  List.range(1, 17).map { index =>
    val location = f"https://mapland.fr/data/ai/images-samples/example-$index%03d.jpg"
    URI.create(location).toURL
  }
// Artifacts the model zoo lists for the image classification application (empty when none are listed).
val objectDetectionsArtifacts: List[Artifact] = {
  val artifactsByApplication = ModelZoo.listModels().asScala
  artifactsByApplication
    .get(Application.CV.IMAGE_CLASSIFICATION)
    .fold(List.empty[Artifact])(_.asScala.toList)
}
// Run every listed model that is not blacklisted against all input images, then print the outcome.
val results =
  for {
    modelArtifact <- objectDetectionsArtifacts.map(ModelArtifact.apply)
    if !blackListed.contains(modelArtifact.ident)
    modelResult <- testModel(modelArtifact, inputImageSources)
  } yield modelResult

showResults(results)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment