Last active
July 29, 2023 21:55
-
-
Save tucan9389/efb96fea6cfca16f8b452251a289f938 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import tensorflow as tf
import coremltools as ct

# Directory holding the exported TensorFlow model. In the original author's
# case this was the output model path produced by TensorFlow Model Maker.
# It must contain a 'saved_model' directory and a 'labels.txt' file.
# TODO: replace the placeholder below with your own export directory.
target_dir_path = "path/to/model_maker_output"
saved_model_path = os.path.join(target_dir_path, 'saved_model')
label_file_path = os.path.join(target_dir_path, 'labels.txt')

# convert tensorflow model into Core ML model
def get_bias_and_scale(preprocessing_type='scale_0_1'):
    """Return the ``(bias, scale)`` pair for a named image preprocessing scheme.

    The result is meant to be passed as the ``bias`` and ``scale`` arguments
    of ``ct.ImageType``.

    Args:
        preprocessing_type: One of
            ``'scale_0_1'``   -- bias ``[0, 0, 0]``, scale ``1/255``;
            ``'scale_m1_p1'`` -- bias ``[-1, -1, -1]``, scale ``1/127.5``;
            ``'pytorch'``     -- bias ``-mean/std`` per channel using the
            constants 0.485/0.229, 0.456/0.224, 0.406/0.225, and scale
            ``1/(0.226*255)``;
            ``'none'``        -- no preprocessing.
            Any other value is treated like ``'scale_0_1'``.

    Returns:
        A tuple ``(bias, scale)`` where ``bias`` is a new three-element list
        and ``scale`` is a float, or ``(None, None)`` for ``'none'``.
    """
    if preprocessing_type == 'none':
        return None, None

    if preprocessing_type == 'scale_m1_p1':
        return [-1, -1, -1], 1 / 127.5

    if preprocessing_type == 'pytorch':
        # A single scalar scale is returned, so one divisor (0.226, close to
        # the mean of 0.229/0.224/0.225) stands in for the three per-channel
        # values used in the bias.
        channel_bias = [-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225]
        return channel_bias, 1 / (0.226 * 255.0)

    # 'scale_0_1' and every unrecognised name share the same default.
    return [0, 0, 0], 1 / 255.0
# Choose the image preprocessing to attach to the Core ML image input.
bias, scale = get_bias_and_scale('scale_0_1')
if bias is not None and scale is not None:
    inputs = [ct.ImageType(bias=bias, scale=scale)]
else:
    # 'none' preprocessing: use ImageType's own defaults.
    inputs = [ct.ImageType()]

# Read the class labels: one label per line, empty lines skipped.
# The context manager closes the file handle once the labels are read.
with open(label_file_path, 'r') as label_file:
    class_labels = [
        label for label in label_file.read().split('\n') if len(label) > 0
    ]
classifier_config = ct.ClassifierConfig(class_labels)

# Convert the SavedModel into a Core ML classifier and write it to disk.
mlmodel = ct.convert(
    saved_model_path,
    inputs=inputs,
    classifier_config=classifier_config,
)
mlmodel.save("MyModel.mlmodel")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment