Skip to content

Instantly share code, notes, and snippets.

@raphaeldelio
Created January 15, 2025 17:51
Show Gist options
  • Select an option

  • Save raphaeldelio/aacc5133f0008d630068ca7d64d35a10 to your computer and use it in GitHub Desktop.

Select an option

Save raphaeldelio/aacc5133f0008d630068ca7d64d35a10 to your computer and use it in GitHub Desktop.
Fine Tuning with DJL example
package dev.raphaeldelio;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.ImageFolder;
import ai.djl.engine.Engine;
import ai.djl.modality.cv.transform.*;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.*;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.*;
import ai.djl.training.loss.SoftmaxCrossEntropyLoss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.tracker.FixedPerVarTracker;
import ai.djl.util.Pair;
import ai.djl.training.util.ProgressBar;
import java.nio.file.Paths;
public class Main {
public static void main(String[] args) {
int batchSize = 32;
int numEpochs = 2;
String datasetPath = "dataset"; // Path to your dataset folder
String modelPath = "build/model"; // Path to save the trained model
try {
// Step 1: Load the Pretrained ResNet18 Embedding
String modelUrl = "djl://ai.djl.pytorch/resnet18_embedding";
Criteria<NDList, NDList> criteria = Criteria.builder()
.setTypes(NDList.class, NDList.class)
.optModelUrls(modelUrl)
.optEngine("PyTorch")
.optProgress(new ProgressBar())
.build();
ZooModel<NDList, NDList> embeddingModel = criteria.loadModel();
Block baseBlock = embeddingModel.getBlock();
// Step 2: Add Fully Connected Layers
Block newBlock = new SequentialBlock()
.add(baseBlock)
.addSingleton(nd -> nd.squeeze(new int[]{2, 3})) // Squeeze dimensions
.add(Linear.builder().setUnits(8).build()) // Fully connected layer
.addSingleton(nd -> nd.softmax(1)); // Apply softmax
Model model = Model.newInstance("AuroraClassifier");
model.setBlock(newBlock);
// Step 3: Load Datasets
RandomAccessDataset trainDataset = loadDataset(datasetPath + "/train", batchSize, Dataset.Usage.TRAIN);
RandomAccessDataset testDataset = loadDataset(datasetPath + "/test", batchSize, Dataset.Usage.TEST);
trainDataset.prepare();
testDataset.prepare();
// Step 4: Configure Trainer
DefaultTrainingConfig config = setupTrainingConfig(baseBlock);
Trainer trainer = model.newTrainer(config);
trainer.initialize(new Shape(batchSize, 3, 224, 224));
// Step 5: Train the Model
EasyTrain.fit(trainer, numEpochs, trainDataset, testDataset);
// Step 6: Save the Trained Model
model.save(Paths.get(modelPath), "AuroraClassifier");
System.out.println("Model saved to " + modelPath);
// Cleanup
model.close();
embeddingModel.close();
} catch (Exception e) {
e.printStackTrace();
}
}
// Function to Load Dataset
private static RandomAccessDataset loadDataset(String path, int batchSize, Dataset.Usage usage) {
float[] mean = {0.485f, 0.456f, 0.406f};
float[] std = {0.229f, 0.224f, 0.225f};
return ImageFolder.builder()
.setRepository(Repository.newInstance("", Paths.get(path)))
.optMaxDepth(1)
.addTransform(new RandomResizedCrop(256, 256)) // Augment for training
.addTransform(new RandomFlipLeftRight()) // Augment for training
.addTransform(new Resize(256, 256)) // Ensure consistent size
.addTransform(new CenterCrop(224, 224)) // Crop for embedding compatibility
.addTransform(new ToTensor())
.addTransform(new Normalize(mean, std)) // Normalize
.setSampling(batchSize, true) // Batch size and shuffle
.build();
}
// Function to Set Up Training Configuration
private static DefaultTrainingConfig setupTrainingConfig(Block baseBlock) {
String outputDir = "build/output";
// Save Model Listener
SaveModelTrainingListener saveListener = new SaveModelTrainingListener(outputDir);
saveListener.setSaveModelCallback(trainer -> {
TrainingResult result = trainer.getTrainingResult();
Model model = trainer.getModel();
float accuracy = result.getValidateEvaluation("Accuracy");
model.setProperty("Accuracy", String.format("%.5f", accuracy));
model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
System.out.println("Model saved after epoch. Accuracy: " + accuracy + ", Loss: " + result.getValidateLoss());
});
// Training Config
DefaultTrainingConfig config = new DefaultTrainingConfig(new SoftmaxCrossEntropyLoss())
.addEvaluator(new Accuracy()) // Add accuracy evaluator
.optDevices(Engine.getInstance().getDevices()) // Automatically select devices
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) // Logging listener
.addTrainingListeners(saveListener); // Save model listener
// Learning Rate Tracker
float lr = 0.001f;
FixedPerVarTracker.Builder lrBuilder = FixedPerVarTracker.builder().setDefaultValue(lr);
for (Pair<String, Parameter> paramPair : baseBlock.getParameters()) {
lrBuilder.put(paramPair.getValue().getId(), 0.1f * lr); // Adjust pretrained layers' learning rate
}
config.optOptimizer(Adam.builder().optLearningRateTracker(lrBuilder.build()).build());
return config;
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment