Last active
November 3, 2024 15:19
-
-
Save anyashka/ff0f0c26e9cbc5a3d0e903ca72cf1462 to your computer and use it in GitHub Desktop.
Updatable Lunch Classifier for Core ML
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Updatable Lunch Classifier for Core ML.
#
# Stage 1 trains an image classifier with Turi Create and exports it as a
# Core ML model. Stage 2 reloads the exported spec with coremltools, marks a
# layer as updatable, attaches a loss function and an SGD optimizer, and saves
# the result as a second, updatable .mlmodel file.

# --- Stage 1: creating a model with Turi Create ---
import turicreate as tc

# with_path=True keeps a 'path' column, which is used below to derive labels.
data = tc.image_analysis.load_images('image/train', with_path=True)

# Any image whose path contains '/healthy' is labelled 'healthy'; everything
# else is labelled 'fast food'.
data['label'] = data['path'].apply(
    lambda path: 'healthy' if '/healthy' in path else 'fast food'
)

model = tc.image_classifier.create(data, target='label')
model.save("LunchImageClassifier.model")
model.export_coreml('LunchImageClassifier.mlmodel')

# --- Stage 2: making the model updatable ---
import coremltools

# Loading & inspecting
coreml_model_path = './LunchImageClassifier.mlmodel'
spec = coremltools.utils.load_spec(coreml_model_path)
builder = coremltools.models.neural_network.NeuralNetworkBuilder(spec=spec)
builder.inspect_layers(last=3)
# Taken from the builder before the calls below; the description edits near
# the end rely on the builder calls being visible through this same name.
model_spec = builder.spec

# Marking the updatable layer & setting the loss function
# NOTE(review): 'fc1' and 'labelProbability' are assumed to match a layer name
# and the probability output name of the exported model -- confirm against the
# inspect_layers() output above.
builder.make_updatable(['fc1'])
builder.set_categorical_cross_entropy_loss(name="lossLayer", input="labelProbability")
# First argument is the epoch count; the list is the set of allowed values.
builder.set_epochs(10, [1, 10, 50])

# Choosing the optimizer
from coremltools.models.neural_network import SgdParams

sgd_params = SgdParams(lr=0.001, batch=8, momentum=0)
# Batch size of 8, with the list as the set of allowed batch sizes.
sgd_params.set_batch(8, [1, 2, 8, 16])
builder.set_sgd_optimizer(sgd_params)

# Describing & saving
# NOTE(review): assumes trainingInput[0] is the image and trainingInput[1] is
# the label, as the description strings imply -- confirm the ordering.
model_spec.description.trainingInput[0].shortDescription = 'Example image of lunch'
model_spec.description.trainingInput[1].shortDescription = 'Associated true label of example image'

from coremltools.models import MLModel

mlmodel_updatable = MLModel(model_spec)
coreml_updatable_model_path = './UpdatableLunchImageClassifier.mlmodel'
mlmodel_updatable.save(coreml_updatable_model_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment