Created February 16, 2023 23:28
-
-
Save Vadbeg/965708ed9ed91b8b863f99b9df10f898 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import coremltools as ct | |
import numpy as np | |
import torch | |
from coremltools.models.neural_network import quantization_utils | |
from PIL import Image | |
from colorizers import siggraph17, load_img | |
from colorizers.siggraph17 import ModelHead | |
def test_core_model(
    coreml_model: ct.models.MLModel,
    torchscript_model: torch.jit.TracedModule,
    image_path: str,
    *,
    rtol: float = 0.1,
    atol: float = 0.1,
) -> bool:
    """Compare the converted Core ML backbone against its TorchScript source.

    The image at ``image_path`` is resized to 256x256 and converted to RGB.
    The TorchScript model receives it as a float tensor of shape
    (1, 3, 256, 256) with pixel values divided by 255. The Core ML model
    receives the PIL image itself under the input name ``image_small``
    (``convert_core_model`` in this file configures that input with a 1/255
    scale). The Core ML output ``colorized_small_image`` is then compared with
    the Torch output via ``numpy.testing.assert_allclose``.

    This is a helper invoked by ``convert_core_model`` with explicit
    arguments, not a pytest test function.

    Args:
        coreml_model: converted model exposing ``image_small`` ->
            ``colorized_small_image``.
        torchscript_model: traced model taking one image tensor.
        image_path: path of the image used for the comparison.
        rtol: relative tolerance passed to ``assert_allclose``.
        atol: absolute tolerance passed to ``assert_allclose``.

    Returns:
        True when the outputs match within tolerance. False otherwise; in
        that case the mismatch report is printed.
    """
    with Image.open(image_path) as source_image:
        # resize()/convert() return new in-memory images, so ``image`` stays
        # usable after the underlying file handle is closed.
        image = source_image.resize(size=(256, 256)).convert(mode='RGB')

    image_array = np.array(image)  # HWC uint8, three channels after convert('RGB')
    img_torch = torch.tensor(np.transpose(image_array, axes=[2, 0, 1])) / 255.0
    img_torch = img_torch.unsqueeze(0)  # add leading batch axis

    # Inference only: no autograd graph is needed for the comparison.
    with torch.no_grad():
        res_torch = torchscript_model(img_torch).cpu().numpy()

    res_coreml = coreml_model.predict(
        data={'image_small': image}
    )['colorized_small_image']

    print(f'Torch shape: {res_torch.shape}')
    print(f'CoreML shape: {res_coreml.shape}')
    try:
        np.testing.assert_allclose(res_torch, res_coreml, rtol=rtol, atol=atol)
    except AssertionError as exc:
        # A numeric or shape mismatch is reported but does not abort the
        # conversion. Any other error is a real bug and propagates.
        print(exc)
        return False
    return True
def test_tail_model(
    coreml_model: ct.models.MLModel,
    torchscript_model: torch.jit.TracedModule,
    image_path: str,
    *,
    rtol: float = 0.01,
    atol: float = 0.01,
) -> bool:
    """Compare the converted Core ML tail model against its TorchScript source.

    The image at ``image_path`` is converted to RGB at its original size. That
    same image is fed as both model inputs:
    - the TorchScript model gets a float tensor of shape (1, 3, H, W) with
      pixels divided by 255;
    - the Core ML model gets the PIL image under ``image`` and
      ``colorized_small_image``.
    Element ``[0]`` of each output (presumably the batch axis; confirm against
    ModelHead) is compared with ``numpy.testing.assert_allclose``.

    This is a helper invoked by ``convert_tail_model`` with explicit
    arguments, not a pytest test function.

    Args:
        coreml_model: converted model exposing ``image`` and
            ``colorized_small_image`` -> ``colorized_image``.
        torchscript_model: traced model taking two image tensors.
        image_path: path of the image used for the comparison.
        rtol: relative tolerance passed to ``assert_allclose``.
        atol: absolute tolerance passed to ``assert_allclose``.

    Returns:
        True when the outputs match within tolerance. False otherwise; in
        that case the mismatch report is printed.
    """
    with Image.open(image_path) as source_image:
        # convert() returns a new in-memory image, so ``image`` stays usable
        # after the underlying file handle is closed.
        image = source_image.convert(mode='RGB')

    image_array = np.array(image)  # HWC uint8, three channels after convert('RGB')
    img_torch = torch.tensor(np.transpose(image_array, axes=[2, 0, 1])) / 255.0
    img_torch = img_torch.unsqueeze(0)  # add leading batch axis

    # Inference only: no autograd graph is needed for the comparison.
    with torch.no_grad():
        res_torch = torchscript_model(img_torch, img_torch).cpu().numpy()[0]

    res_coreml = coreml_model.predict(
        data={
            'image': image,
            'colorized_small_image': image,
        }
    )['colorized_image'][0]

    print(f'Torch shape: {res_torch.shape}')
    print(f'CoreML shape: {res_coreml.shape}')
    try:
        np.testing.assert_allclose(res_torch, res_coreml, rtol=rtol, atol=atol)
    except AssertionError as exc:
        # A numeric or shape mismatch is reported but does not abort the
        # conversion. Any other error is a real bug and propagates.
        print(exc)
        return False
    return True
def convert_core_model(test_image_path: str = 'imgs/ansel_adams.jpg') -> ct.models.MLModel:
    """Convert the pretrained siggraph17 backbone into a 16-bit Core ML model.

    Steps:
    1. Trace ``siggraph17(pretrained=True)`` in eval mode on a random
       (1, 3, 256, 256) tensor.
    2. Convert it to a neural-network Core ML model with a 256x256 image input
       named ``image_small`` (pixels scaled by 1/255).
    3. Rename the first output to ``colorized_small_image`` and attach
       descriptions and metadata.
    4. Quantize the weights to 16 bits.
    5. Compare the result against the traced model with ``test_core_model``.

    Args:
        test_image_path: image used for the Torch-vs-Core ML sanity check.

    Returns:
        The quantized Core ML model.
    """
    siggraph17_model = siggraph17(pretrained=True).eval()
    example_input = (torch.rand(1, 3, 256, 256),)
    traced_model = torch.jit.trace(siggraph17_model, example_input)

    coreml_model = ct.convert(
        traced_model,
        inputs=[
            ct.ImageType(
                name="image_small",
                shape=(1, 3, 256, 256),
                scale=1 / 255.,
            ),
        ],
        # The backend is pinned explicitly. quantize_weights() below, the
        # MLModel(spec) rebuild, and the .mlmodel file the caller saves all
        # require the neural-network backend. Newer coremltools releases
        # default to "mlprogram" instead.
        convert_to='neuralnetwork',
    )

    spec = coreml_model.get_spec()
    # Read the auto-generated output name from the public spec. This avoids
    # the private ``output_description._fd_spec`` attribute.
    old_name = spec.description.output[0].name
    new_name = 'colorized_small_image'
    ct.utils.rename_feature(spec, old_name, new_name, rename_outputs=True)

    spec.description.input[0].shortDescription = 'Input RGB image, with shape (3, 256, 256)'
    spec.description.output[0].shortDescription = 'Colorized small image as array, with shape (3, 256, 256)'

    coreml_model = ct.models.MLModel(spec)
    coreml_model.author = 'Vadim Titko'
    coreml_model.license = 'AI Future'
    coreml_model.short_description = 'Backbone colorization model for grayscale images'
    coreml_model.version = '1.0'

    # Store the weights as 16-bit values to shrink the model file.
    coreml_model = quantization_utils.quantize_weights(full_precision_model=coreml_model, nbits=16)

    test_core_model(
        coreml_model=coreml_model,
        torchscript_model=traced_model,
        image_path=test_image_path,
    )
    return coreml_model
def convert_tail_model(test_image_path: str = 'imgs/ansel_adams.jpg') -> ct.models.MLModel:
    """Convert ``ModelHead`` (the tail colorization stage) into a Core ML model.

    Steps:
    1. Trace ``ModelHead`` on two random (1, 3, 512, 512) tensors.
    2. Convert it to a neural-network Core ML model with two image inputs,
       ``image`` and ``colorized_small_image``. Both use a flexible spatial
       size (``ct.RangeDim``) and pixels scaled by 1/255.
    3. Rename the first output to ``colorized_image`` and attach descriptions
       and metadata.
    4. Compare the result against the traced model with ``test_tail_model``.

    Args:
        test_image_path: image used for the Torch-vs-Core ML sanity check.

    Returns:
        The Core ML model (weights are not quantized).
    """
    # .eval() matches convert_core_model. Tracing records whichever
    # train/eval behaviour is active, so switch to inference mode first.
    # This is a no-op if ModelHead has no dropout/batch-norm layers.
    model_head = ModelHead().eval()
    example_input = (
        torch.rand(1, 3, 512, 512),
        torch.rand(1, 3, 512, 512),
    )
    tail_traced_model = torch.jit.trace(model_head, example_input)

    tail_coreml_model = ct.convert(
        tail_traced_model,
        inputs=[
            ct.ImageType(
                name="image",
                shape=(1, 3, ct.RangeDim(), ct.RangeDim()),
                scale=1 / 255.,
            ),
            ct.ImageType(
                name="colorized_small_image",
                shape=(1, 3, ct.RangeDim(), ct.RangeDim()),
                scale=1 / 255.,
            ),
        ],
        # The backend is pinned explicitly. The MLModel(spec) rebuild and the
        # .mlmodel file the caller saves require the neural-network backend.
        # Newer coremltools releases default to "mlprogram" instead.
        convert_to='neuralnetwork',
    )

    spec = tail_coreml_model.get_spec()
    # Read the auto-generated output name from the public spec. This avoids
    # the private ``output_description._fd_spec`` attribute.
    old_name = spec.description.output[0].name
    new_name = 'colorized_image'
    ct.utils.rename_feature(spec, old_name, new_name, rename_outputs=True)

    spec.description.input[0].shortDescription = 'Input RGB gray image, with shape (3, H, W)'
    spec.description.input[1].shortDescription = 'Input RGB colorized image, with shape (3, H, W)'
    spec.description.output[0].shortDescription = 'Array image, with shape (3, H, W), with values from 0.0 to 255.0'

    tail_coreml_model = ct.models.MLModel(spec)
    tail_coreml_model.author = 'Vadim Titko'
    tail_coreml_model.license = 'AI Future'
    tail_coreml_model.short_description = 'Tail colorization model for grayscale images'
    tail_coreml_model.version = '1.0'

    test_tail_model(
        coreml_model=tail_coreml_model,
        torchscript_model=tail_traced_model,
        image_path=test_image_path,
    )
    return tail_coreml_model
if __name__ == '__main__':
    # Script entry point: convert the backbone first, then the tail model.
    # Each model is written to its own .mlmodel file before the next
    # conversion starts.
    conversion_jobs = (
        ('Convert core colorization model', convert_core_model, "colorizer_core.mlmodel"),
        ('Convert tail colorization model', convert_tail_model, "colorizer_tail.mlmodel"),
    )
    for banner, build_model, output_path in conversion_jobs:
        print(banner)
        converted = build_model()
        converted.save(output_path)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.