Created
January 15, 2025 17:51
-
-
Save raphaeldelio/aacc5133f0008d630068ca7d64d35a10 to your computer and use it in GitHub Desktop.
Fine Tuning with DJL example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package dev.raphaeldelio; | |
| import ai.djl.Model; | |
| import ai.djl.basicdataset.cv.classification.ImageFolder; | |
| import ai.djl.engine.Engine; | |
| import ai.djl.modality.cv.transform.*; | |
| import ai.djl.ndarray.NDList; | |
| import ai.djl.ndarray.types.Shape; | |
| import ai.djl.nn.Block; | |
| import ai.djl.nn.Parameter; | |
| import ai.djl.nn.SequentialBlock; | |
| import ai.djl.nn.core.Linear; | |
| import ai.djl.repository.Repository; | |
| import ai.djl.repository.zoo.Criteria; | |
| import ai.djl.repository.zoo.ZooModel; | |
| import ai.djl.training.*; | |
| import ai.djl.training.dataset.Dataset; | |
| import ai.djl.training.dataset.RandomAccessDataset; | |
| import ai.djl.training.evaluator.Accuracy; | |
| import ai.djl.training.listener.*; | |
| import ai.djl.training.loss.SoftmaxCrossEntropyLoss; | |
| import ai.djl.training.optimizer.Adam; | |
| import ai.djl.training.tracker.FixedPerVarTracker; | |
| import ai.djl.util.Pair; | |
| import ai.djl.training.util.ProgressBar; | |
| import java.nio.file.Paths; | |
| public class Main { | |
| public static void main(String[] args) { | |
| int batchSize = 32; | |
| int numEpochs = 2; | |
| String datasetPath = "dataset"; // Path to your dataset folder | |
| String modelPath = "build/model"; // Path to save the trained model | |
| try { | |
| // Step 1: Load the Pretrained ResNet18 Embedding | |
| String modelUrl = "djl://ai.djl.pytorch/resnet18_embedding"; | |
| Criteria<NDList, NDList> criteria = Criteria.builder() | |
| .setTypes(NDList.class, NDList.class) | |
| .optModelUrls(modelUrl) | |
| .optEngine("PyTorch") | |
| .optProgress(new ProgressBar()) | |
| .build(); | |
| ZooModel<NDList, NDList> embeddingModel = criteria.loadModel(); | |
| Block baseBlock = embeddingModel.getBlock(); | |
| // Step 2: Add Fully Connected Layers | |
| Block newBlock = new SequentialBlock() | |
| .add(baseBlock) | |
| .addSingleton(nd -> nd.squeeze(new int[]{2, 3})) // Squeeze dimensions | |
| .add(Linear.builder().setUnits(8).build()) // Fully connected layer | |
| .addSingleton(nd -> nd.softmax(1)); // Apply softmax | |
| Model model = Model.newInstance("AuroraClassifier"); | |
| model.setBlock(newBlock); | |
| // Step 3: Load Datasets | |
| RandomAccessDataset trainDataset = loadDataset(datasetPath + "/train", batchSize, Dataset.Usage.TRAIN); | |
| RandomAccessDataset testDataset = loadDataset(datasetPath + "/test", batchSize, Dataset.Usage.TEST); | |
| trainDataset.prepare(); | |
| testDataset.prepare(); | |
| // Step 4: Configure Trainer | |
| DefaultTrainingConfig config = setupTrainingConfig(baseBlock); | |
| Trainer trainer = model.newTrainer(config); | |
| trainer.initialize(new Shape(batchSize, 3, 224, 224)); | |
| // Step 5: Train the Model | |
| EasyTrain.fit(trainer, numEpochs, trainDataset, testDataset); | |
| // Step 6: Save the Trained Model | |
| model.save(Paths.get(modelPath), "AuroraClassifier"); | |
| System.out.println("Model saved to " + modelPath); | |
| // Cleanup | |
| model.close(); | |
| embeddingModel.close(); | |
| } catch (Exception e) { | |
| e.printStackTrace(); | |
| } | |
| } | |
| // Function to Load Dataset | |
| private static RandomAccessDataset loadDataset(String path, int batchSize, Dataset.Usage usage) { | |
| float[] mean = {0.485f, 0.456f, 0.406f}; | |
| float[] std = {0.229f, 0.224f, 0.225f}; | |
| return ImageFolder.builder() | |
| .setRepository(Repository.newInstance("", Paths.get(path))) | |
| .optMaxDepth(1) | |
| .addTransform(new RandomResizedCrop(256, 256)) // Augment for training | |
| .addTransform(new RandomFlipLeftRight()) // Augment for training | |
| .addTransform(new Resize(256, 256)) // Ensure consistent size | |
| .addTransform(new CenterCrop(224, 224)) // Crop for embedding compatibility | |
| .addTransform(new ToTensor()) | |
| .addTransform(new Normalize(mean, std)) // Normalize | |
| .setSampling(batchSize, true) // Batch size and shuffle | |
| .build(); | |
| } | |
| // Function to Set Up Training Configuration | |
| private static DefaultTrainingConfig setupTrainingConfig(Block baseBlock) { | |
| String outputDir = "build/output"; | |
| // Save Model Listener | |
| SaveModelTrainingListener saveListener = new SaveModelTrainingListener(outputDir); | |
| saveListener.setSaveModelCallback(trainer -> { | |
| TrainingResult result = trainer.getTrainingResult(); | |
| Model model = trainer.getModel(); | |
| float accuracy = result.getValidateEvaluation("Accuracy"); | |
| model.setProperty("Accuracy", String.format("%.5f", accuracy)); | |
| model.setProperty("Loss", String.format("%.5f", result.getValidateLoss())); | |
| System.out.println("Model saved after epoch. Accuracy: " + accuracy + ", Loss: " + result.getValidateLoss()); | |
| }); | |
| // Training Config | |
| DefaultTrainingConfig config = new DefaultTrainingConfig(new SoftmaxCrossEntropyLoss()) | |
| .addEvaluator(new Accuracy()) // Add accuracy evaluator | |
| .optDevices(Engine.getInstance().getDevices()) // Automatically select devices | |
| .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) // Logging listener | |
| .addTrainingListeners(saveListener); // Save model listener | |
| // Learning Rate Tracker | |
| float lr = 0.001f; | |
| FixedPerVarTracker.Builder lrBuilder = FixedPerVarTracker.builder().setDefaultValue(lr); | |
| for (Pair<String, Parameter> paramPair : baseBlock.getParameters()) { | |
| lrBuilder.put(paramPair.getValue().getId(), 0.1f * lr); // Adjust pretrained layers' learning rate | |
| } | |
| config.optOptimizer(Adam.builder().optLearningRateTracker(lrBuilder.build()).build()); | |
| return config; | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment