-
-
Save innat/1ef444f8743d91f993da979c66ae31ba to your computer and use it in GitHub Desktop.
Netron.app example to visualize a TensorFlow 2.x model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
pip install tensorflow | |
pip install tf2onnx keras2onnx onnxmltools | |
""" | |
import os | |
import pdb | |
import json | |
import traceback | |
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" | |
import tensorflow as tf # v2.4 | |
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# TensorFlow requires every visible GPU to share the same memory-growth setting
# (enabling it on only some devices raises a ValueError at runtime initialisation),
# so the flag is applied to all physical GPUs, not just the first one.
# On CPU-only hosts the list is empty and this loop does nothing.
for _gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(_gpu, True)
############################################################ | |
# FOCUSNET # | |
############################################################ | |
class ConvBlock3D(tf.keras.layers.Layer):
    """Stack of Conv3D + BatchNormalization pairs with optional dropout and max-pooling.

    For every entry of ``filters`` one L2-regularised (factor 0.1) Conv3D followed by
    one BatchNormalization is appended to an internal Sequential. A Dropout layer is
    inserted right after the first pair when ``dropout`` is not None.

    Args:
        filters: sequence of output-channel counts, one per Conv3D.
        kernel_size, strides, padding, dilation_rate, activation: forwarded to every Conv3D.
        trainable: forwarded to every BatchNormalization (the block itself keeps the
            Keras default).
        dropout: dropout rate used after the first conv/BN pair, or None for no dropout.
        pool: if True, ``call`` also returns a 2x2x2 max-pooled copy of its output.
        name: prefix used to build this block's layer names.
    """

    def __init__(self, filters, kernel_size=(3,3,3), strides=(1, 1, 1), padding='same',
                 dilation_rate=(1,1,1),
                 activation='relu',
                 trainable=False,
                 dropout=None,
                 pool=False,
                 name=''):
        super(ConvBlock3D, self).__init__(name='{}_ConvBlock3D'.format(name))
        self.pool = pool
        self.conv_layer = tf.keras.Sequential(name='{}_Sequential'.format(name))
        stack = self.conv_layer
        for idx, n_filters in enumerate(filters):
            stack.add(tf.keras.layers.Conv3D(
                filters=n_filters,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                dilation_rate=dilation_rate,
                activation=activation,
                kernel_regularizer=tf.keras.regularizers.l2(0.1),
                name='Conv_{}'.format(idx),
            ))
            stack.add(tf.keras.layers.BatchNormalization(
                trainable=trainable, name='BNorm_{}'.format(idx)))
            # Dropout is inserted only once, directly after the first conv/BN pair.
            if idx == 0 and dropout is not None:
                stack.add(tf.keras.layers.Dropout(
                    rate=dropout, name='DropOut_{}'.format(idx)))
        if pool:
            self.pool_layer = tf.keras.layers.MaxPooling3D(
                (2,2,2), strides=(2,2,2), name='{}_Pool'.format(name))

    def call(self, x):
        """Return the conv-stack output, plus its pooled version when ``pool`` is set."""
        features = self.conv_layer(x)
        if not self.pool:
            return features
        return features, self.pool_layer(features)

    def get_config(self):
        """Return the sub-layers themselves (not constructor arguments).

        NOTE(review): this config lists layer objects, presumably so the nested
        structure shows up in .json/.h5 exports for visualisation; it is not a
        constructor-argument config.
        """
        config = {'conv_layer': self.conv_layer}
        if self.pool:
            config['pool_layer'] = self.pool_layer
        return config
class ConvBlock3DSERes(tf.keras.layers.Layer):
    """Residual ConvBlock3D with optional squeeze-and-excitation (channel-wise attention).

    Computes ``y = x + SE(F(x))``, where ``F`` is a ConvBlock3D built from ``filters``
    and ``SE`` rescales each channel of ``F(x)`` by a sigmoid gate (skipped when
    ``squeeze_ratio`` is None).

    When ``init`` is True, ``x`` is first projected by a 1x1x1 ReLU Conv3D to
    ``filters[-1]`` channels so the residual addition lines up. When ``init`` is
    False, the addition relies on ``x`` already having ``filters[-1]`` channels, or
    a single channel that TensorFlow broadcasts.

    Args:
        filters: sequence of Conv3D channel counts for the residual branch.
        kernel_size, strides, padding, dilation_rate, activation, trainable, dropout:
            forwarded to the inner ConvBlock3D.
        pool: if True, ``call`` returns ``(y, maxpool(y))`` using a 2x2x2 max-pool.
        squeeze_ratio: channel-reduction factor of the SE bottleneck, or None to
            disable SE.
        init: if True, apply the 1x1x1 channel-equalising projection to the input.
        name: prefix used to build this block's layer names.

    Ref: https://github.com/imkhan2/se-resnet/blob/master/se_resnet.py
    """

    def __init__(self, filters, kernel_size=(3,3,3), strides=(1, 1, 1), padding='same',
                 dilation_rate=(1,1,1),
                 activation='relu',
                 trainable=False,
                 dropout=None,
                 pool=False,
                 squeeze_ratio=None,
                 init=False,
                 name=''):
        super(ConvBlock3DSERes, self).__init__(name='{}_ConvBlock3DSERes'.format(name))
        # The residual branch ends with a Conv3D of filters[-1] channels. The skip path
        # and the SE gate must match that, not filters[0]. The two are identical for
        # the uniform filter lists used in this file.
        out_channels = filters[-1]

        self.init = init
        if self.init:
            self.convblock_filterequalizer = tf.keras.layers.Conv3D(
                filters=out_channels, kernel_size=(1,1,1), strides=(1,1,1), padding='same',
                activation='relu', name='{}_ConnvBlockInit'.format(name))
        self.convblock_res = ConvBlock3D(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding,
                                         dilation_rate=dilation_rate,
                                         activation=activation,
                                         trainable=trainable,
                                         dropout=dropout,
                                         pool=False,
                                         name=name)

        # Squeeze (global average pool) -> bottleneck 1x1x1 conv -> sigmoid gate per channel.
        self.squeeze_ratio = squeeze_ratio
        if self.squeeze_ratio is not None:
            self.seblock = tf.keras.Sequential(name='{}_SqueezeExcitation'.format(name))
            self.seblock.add(tf.keras.layers.GlobalAveragePooling3D())
            self.seblock.add(tf.keras.layers.Reshape(target_shape=(1,1,1,out_channels)))
            self.seblock.add(tf.keras.layers.Conv3D(filters=out_channels//squeeze_ratio, kernel_size=(1,1,1), strides=(1,1,1), padding='same',
                                                    activation='relu'))
            self.seblock.add(tf.keras.layers.Conv3D(filters=out_channels, kernel_size=(1,1,1), strides=(1,1,1), padding='same',
                                                    activation='sigmoid'))

        self.pool = pool
        if self.pool:
            self.pool_layer = tf.keras.layers.MaxPooling3D((2,2,2), strides=(2,2,2), name='{}_Pool'.format(name))

    def call(self, x):
        """Return ``x + SE(F(x))``, plus its 2x2x2 max-pooled version when ``pool`` is set."""
        if self.init:
            x = self.convblock_filterequalizer(x)
        x_res = self.convblock_res(x)
        if self.squeeze_ratio is not None:
            x_se = self.seblock(x_res)  # squeeze and then get excitation factor
            x_res = tf.math.multiply(x_res, x_se)  # excited block
        y = x + x_res
        if self.pool:
            return y, self.pool_layer(y)
        else:
            return y

    def get_config(self):
        """Return the sub-layers that exist for this configuration.

        Only attributes that were actually created are included. ``seblock`` is
        absent when ``squeeze_ratio`` is None, which previously raised AttributeError.
        """
        config = {'convblock_res': self.convblock_res}
        if self.squeeze_ratio is not None:
            config['seblock'] = self.seblock
        if self.init:
            config['convblock_filterequalizer'] = self.convblock_filterequalizer
        if self.pool:
            config['pool_layer'] = self.pool_layer
        return config
class UpConvBlock3D(tf.keras.layers.Layer):
    """Learned 3-D upsampling: a single ReLU Conv3DTranspose wrapped in a Sequential.

    With the defaults (kernel 2x2x2, stride 2x2x2, 'same' padding) every spatial
    dimension is doubled.

    Args:
        filters: output channel count of the transposed convolution.
        kernel_size, strides, padding: forwarded to Conv3DTranspose.
        trainable: accepted for signature compatibility but currently unused; the
            BatchNormalization that would consume it is disabled.
        name: prefix used to build this block's layer names.
    """

    def __init__(self, filters, kernel_size=(2,2,2), strides=(2, 2, 2), padding='same', trainable=False, name=''):
        super(UpConvBlock3D, self).__init__(name='{}_UpConv3D'.format(name))
        self.upconv_layer = tf.keras.Sequential(name='{}_Sequential'.format(name))
        # The transposed-conv name embeds this block's full Keras name (self.name).
        upsample = tf.keras.layers.Conv3DTranspose(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(0.1),
            name='UpConv_{}'.format(self.name),
        )
        self.upconv_layer.add(upsample)

    def call(self, x):
        """Apply the transposed convolution to ``x``."""
        upsampled = self.upconv_layer(x)
        return upsampled

    def get_config(self):
        """Return the wrapped Sequential (a layer object, not constructor arguments)."""
        return dict(upconv_layer=self.upconv_layer)
class ModelFocusNetZDil1(tf.keras.Model):
    """FocusNet-style 3-D segmentation model with dense ASPP and x/y-only dilation.

    Wiring (see ``call``):
      * Block1 / Block2 - SE-residual conv blocks. Block1 (kernel 3x3x1) also
        max-pools by 2, producing ``pool1``.
      * Block3..Block6 - dense ASPP. Each dilated ConvBlock3D receives ``conv2``
        concatenated with every earlier ASPP output.
      * Block7 - SE-residual block on the dense ASPP features.
      * Block8 - SE-residual block on ``concat([pool1, conv7])``. With ``deepsup``,
        a 1x1x1 head (Block8_1) produces an auxiliary prediction at the pooled
        resolution.
      * Block9 - 2x transposed-conv upsampling (Block9_1), then an SE-residual block
        on ``concat([conv1, up9])``.
      * Block10 - 1x1x1 Conv3D classifier.

    Dilation in x/y is 2 for Block2 and 3, 6, 12, 18 for Block3..Block6. Dilation in
    z is 1 everywhere.

    NOTE(review): concatenating ``conv1`` with ``up9`` assumes the spatial input
    dimensions are even, so pool-then-upsample restores the original size.
    TODO confirm for other input shapes.

    Args:
        class_count: number of output channels of the classifier head(s).
        activation: activation of the classifier head(s).
        deepsup: if True, ``call`` returns ``(conv8_1, conv10)`` instead of ``conv10``.
        trainable: forwarded to the sub-blocks, where it controls their
            BatchNormalization layers.
        verbose: if True, ``call`` prints intermediate tensor shapes.
    """

    def __init__(self, class_count, activation='softmax', deepsup=False, trainable=False, verbose=False):
        super(ModelFocusNetZDil1, self).__init__(name='ModelFocusNetZDil1')
        self.verbose = verbose
        self.deepsup = deepsup
        dropout = [None, 0.25, 0.25, 0.25, 0.25, 0.25, None, None]
        filters = [[10,10], [20,20]]
        dilation_xy = [1, 2, 3, 6, 12, 18]
        dilation_z = [1, 1, 1, 1, 1 , 1]

        # NOTE(review): Keras tracks sub-layers in assignment order, which presumably
        # fixes weight ordering in saved files. Keep the order of the blocks below stable.

        # Feat Extraction (SE-Res Blocks)
        # Block1: Dim/2 (e.g. 96/2=48, 240/2=120)(rp=(3,5,10),(3,5,10))
        self.convblock1 = ConvBlock3DSERes(filters=filters[0], kernel_size=(3,3,1), dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, dropout=dropout[0], pool=True , squeeze_ratio=2, name='Block1')
        # Block2: Dim/2 (e.g. 96/2=48, 240/2=120)(rp=(14,18),(12,14))
        self.convblock2 = ConvBlock3DSERes(filters=filters[0], dilation_rate=(dilation_xy[1], dilation_xy[1], dilation_z[1]), trainable=trainable, dropout=dropout[0], pool=False, squeeze_ratio=2, name='Block2')

        # Dense ASPP
        # Block3: Dim/2 (e.g. 96/2=48, 240/2=120) (rp=(24,30),(16,18))
        self.convblock3 = ConvBlock3D(filters=filters[1][:], dilation_rate=(dilation_xy[2], dilation_xy[2], dilation_z[2]), trainable=trainable, dropout=dropout[1], pool=False, name='Block3_ASPP')
        # Block4: Dim/2 (e.g. 96/2=48, 240/2=120) (rp=(42,54),(20,22))
        self.convblock4 = ConvBlock3D(filters=filters[1][:], dilation_rate=(dilation_xy[3], dilation_xy[3], dilation_z[3]), trainable=trainable, dropout=dropout[2], pool=False, name='Block4_ASPP')
        # Block5: Dim/2 (e.g. 96/2=48, 240/2=120) (rp=(78,102),(24,26))
        self.convblock5 = ConvBlock3D(filters=filters[1][:], dilation_rate=(dilation_xy[4], dilation_xy[4], dilation_z[4]), trainable=trainable, dropout=dropout[3], pool=False, name='Block5_ASPP')
        # Block6: Dim/2 (e.g. 96/2=48, 240/2=120) (rp=(138,176),(28,30))
        self.convblock6 = ConvBlock3D(filters=filters[1][:], dilation_rate=(dilation_xy[5], dilation_xy[5], dilation_z[5]), trainable=trainable, dropout=dropout[4], pool=False, name='Block6_ASPP')
        # Block7: Dim/2 (e.g. 96/2=48) (rp=(176,180),(32,44))
        self.convblock7 = ConvBlock3DSERes(filters=filters[1], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, dropout=dropout[5], pool=False, squeeze_ratio=2, init=True, name='Block7')

        # Upstream
        # Block8: Dim/2 (e.g. 96/2=48)
        self.convblock8 = ConvBlock3DSERes(filters=filters[1], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, dropout=dropout[6], pool=False, squeeze_ratio=2, init=True, name='Block8')
        if self.deepsup:
            # Auxiliary (deep-supervision) classifier at the pooled resolution.
            self.convblock8_1 = tf.keras.layers.Conv3D(filters=class_count, strides=(1,1,1), kernel_size=(1,1,1), padding='same',
                                                       dilation_rate=(1,1,1),
                                                       activation=activation,
                                                       name='Block8_1')
        # Block9: Dim/1 (e.g. 96/1 = 96)
        self.upconvblock9 = UpConvBlock3D(filters=filters[0][0], trainable=trainable, name='Block9_1')
        self.convblock9 = ConvBlock3DSERes(filters=filters[0], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, dropout=dropout[7], pool=False, squeeze_ratio=2, init=True, name='Block9')

        # Final
        self.convblock10 = tf.keras.layers.Conv3D(filters=class_count, strides=(1,1,1), kernel_size=(1,1,1), padding='same',
                                                  dilation_rate=(1,1,1),
                                                  activation=activation,
                                                  name='Block10')

    def call(self, x):
        """Run the network on ``x``.

        Returns ``conv10``, or ``(conv8_1, conv10)`` when ``deepsup`` is True. Prints
        intermediate shapes when ``verbose`` is True.
        """
        # Feat Extraction (SE-Res Blocks)
        conv1, pool1 = self.convblock1(x)
        conv2 = self.convblock2(pool1)

        # Dense ASPP: each block sees all previous features concatenated on the channel axis.
        conv3 = self.convblock3(conv2)
        conv3_op = tf.concat([conv2, conv3], axis=-1)
        conv4 = self.convblock4(conv3_op)
        conv4_op = tf.concat([conv3_op, conv4], axis=-1)
        conv5 = self.convblock5(conv4_op)
        conv5_op = tf.concat([conv4_op, conv5], axis=-1)
        conv6 = self.convblock6(conv5_op)
        conv6_op = tf.concat([conv5_op, conv6], axis=-1)
        conv7 = self.convblock7(conv6_op)

        # Upstream
        # Pixel-wise attention can be added here
        conv8 = self.convblock8(tf.concat([pool1, conv7], axis=-1))
        if self.deepsup:
            conv8_1 = self.convblock8_1(conv8)
        up9 = self.upconvblock9(conv8)
        # Pixel-wise attention can be added here
        conv9 = self.convblock9(tf.concat([conv1, up9], axis=-1))

        # Final
        conv10 = self.convblock10(conv9)

        if self.verbose:
            print(' ---------- Model: ', self.name)
            print(' - x: ', x.shape)
            print(' - conv1: ', conv1.shape)
            print(' - conv2: ', conv2.shape)
            print(' - conv3_op: ', conv3_op.shape)
            print(' - conv4_op: ', conv4_op.shape)
            print(' - conv5_op: ', conv5_op.shape)
            print(' - conv6_op: ', conv6_op.shape)
            print(' - conv7: ', conv7.shape)
            print(' - conv8: ', conv8.shape)
            print(' - conv9: ', conv9.shape)
            print(' - conv10: ', conv10.shape)

        if self.deepsup:
            return conv8_1, conv10
        else:
            return conv10

    def build_graph(self, dim):
        """Wrap this subclassed model in a functional ``tf.keras.Model``.

        Uses an explicit Input of shape ``dim``, without the batch axis. The
        functional wrapper is what the script below summarises and saves
        (.h5 / .json / ONNX).
        """
        x = tf.keras.Input(shape=(dim), name='{}-Input'.format(self.name))
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

    def get_config(self):
        """Return every sub-block of the model, keyed by attribute name.

        ``upconvblock9`` was previously missing from this listing; it is now included.
        """
        config = {
            'convblock1': self.convblock1,
            'convblock2': self.convblock2,
            'convblock3': self.convblock3,
            'convblock4': self.convblock4,
            'convblock5': self.convblock5,
            'convblock6': self.convblock6,
            'convblock7': self.convblock7,
            'convblock8': self.convblock8,
            'upconvblock9': self.upconvblock9,
            'convblock9': self.convblock9,
            'convblock10': self.convblock10,
        }
        if self.deepsup:
            config['convblock8_1'] = self.convblock8_1
        return config
############################################################ | |
# UTILS # | |
############################################################ | |
@tf.function
def write_model_trace(model, X):
    """Call ``model`` on ``X`` inside a ``tf.function`` and return its output.

    Running the model through a traced function lets ``tf.summary.trace_export``
    (used by the script below) capture the resulting graph for TensorBoard.
    """
    predictions = model(X)
    return predictions
if __name__ == "__main__":
    if 1:
        print(' --------------------- [ModelFocusNetZDil1] --------------------- ')

        # Using the summary() function
        raw_shape = (140, 140, 40, 1)
        model = ModelFocusNetZDil1(class_count=10)
        model = model.build_graph(raw_shape)
        model.summary(line_length=150)

        # Using the save() function
        _ = model(tf.ones((1,*raw_shape)))
        model.save('ModelFocusNetZDil1')  # Keras ModelSave format; unable to properly visualize .pb file in www.netron.app
        model.save('ModelFocusNetZDil1.h5', save_format='h5')  # Loads in www.netron.app for viz purposes, shows the high level blocks

        # Using the .to_json() and .get_config() function
        with open('ModelFocusNetZDil1.json', 'w') as fp:
            json.dump(json.loads(model.to_json()), fp, indent=4)
        # model.get_config()

        # Using tf2onnx
        import onnx
        import tf2onnx
        # The explicit TensorSpec ensures the batch size shows up properly in tf2onnx.
        model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model, [tf.TensorSpec((1,*raw_shape), tf.float32, name='ModelInput')])
        model_onnx = onnx.shape_inference.infer_shapes(model_proto)
        tf2onnx.utils.save_protobuf('ModelFocusNetZDil1_tf2onxx.onnx', model_onnx)

        # Using tensorboard: trace the first call of the tf.function, then export the graph.
        tf.summary.trace_on(graph=True, profiler=False)
        _ = write_model_trace(model, tf.ones(shape=(1,*raw_shape), dtype=tf.float32))
        writer = tf.summary.create_file_writer('ModelFocusNetZDil1_TBoard')
        try:
            with writer.as_default():
                tf.summary.trace_export(name=model.name, step=0, profiler_outdir=None)
                writer.flush()
        finally:
            # Release the event-file handle even if the export fails.
            writer.close()
        print(' - Run command --> tensorboard --logdir=ModelFocusNetZDil1_TBoard --port=6100')

        # Drop into the debugger to inspect `model` interactively; remove for non-interactive runs.
        pdb.set_trace()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
pip install tensorflow | |
pip install tf2onnx | |
""" | |
import os | |
import pdb | |
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" | |
import tensorflow as tf # v2.4 | |
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# TensorFlow requires every visible GPU to share the same memory-growth setting
# (enabling it on only some devices raises a ValueError at runtime initialisation),
# so the flag is applied to all physical GPUs, not just the first one.
# On CPU-only hosts the list is empty and this loop does nothing.
for _gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(_gpu, True)
############################################################ | |
# INCEPTION MODEL # | |
############################################################ | |
# Model Ref: https://towardsdatascience.com/model-sub-classing-and-custom-training-loop-from-scratch-in-tensorflow-2-cc1d4f10fb4e | |
class ConvModule(tf.keras.layers.Layer):
    """Conv2D -> BatchNormalization -> ReLU building block.

    Args:
        kernel_num: number of Conv2D filters (output channels).
        kernel_size, strides, padding: forwarded to Conv2D.
    """

    def __init__(self, kernel_num, kernel_size, strides, padding='same'):
        super(ConvModule, self).__init__()
        self.conv = tf.keras.layers.Conv2D(
            filters=kernel_num, kernel_size=kernel_size, strides=strides, padding=padding)
        self.bn = tf.keras.layers.BatchNormalization()

    def call(self, input_tensor):
        """Apply convolution, batch normalisation and ReLU, in that order."""
        normalized = self.bn(self.conv(input_tensor))
        return tf.nn.relu(normalized)

    def get_config(self):
        """Return the sub-layers themselves (not constructor arguments)."""
        return dict(conv=self.conv, bn=self.bn)
class InceptionModule(tf.keras.layers.Layer):
    """Two parallel ConvModule branches (1x1 and 3x3, stride 1) concatenated on channels.

    Args:
        kernel_size1x1: number of filters in the 1x1 branch (a filter count, despite
            the parameter name).
        kernel_size3x3: number of filters in the 3x3 branch.
        name: suffix for the layer name ``InceptionModule_<name>``.
    """

    def __init__(self, kernel_size1x1, kernel_size3x3, name=''):
        # Pass the name by keyword. The first positional parameter of
        # keras Layer.__init__ is `trainable`. The previous positional call set
        # trainable to this string and left the layer auto-named.
        super(InceptionModule, self).__init__(name='InceptionModule_{}'.format(name))
        self.conv1 = ConvModule(kernel_size1x1, kernel_size=(1,1), strides=(1,1))
        self.conv2 = ConvModule(kernel_size3x3, kernel_size=(3,3), strides=(1,1))
        self.cat = tf.keras.layers.Concatenate()

    def call(self, input_tensor):
        """Run both branches on the same input and concatenate their outputs."""
        x_1x1 = self.conv1(input_tensor)
        x_3x3 = self.conv2(input_tensor)
        x = self.cat([x_1x1, x_3x3])
        return x

    def get_config(self):
        """Return the sub-layers themselves (not constructor arguments)."""
        return {'conv1':self.conv1, 'conv2':self.conv2, 'cat':self.cat}
class DownsampleModule(tf.keras.layers.Layer):
    """Reduce spatial resolution by concatenating a strided conv branch with a max-pool branch.

    Both branches use a 3x3 window with stride 2 and 'valid' padding, so their
    spatial outputs match and can be concatenated on the channel axis.

    Args:
        kernel_size: number of filters in the conv branch (a filter count, despite
            the parameter name).
        name: suffix for the layer name ``DownsampleModule_<name>``.
    """

    def __init__(self, kernel_size, name=''):
        super(DownsampleModule, self).__init__(name='DownsampleModule_{}'.format(name))
        self.conv3 = ConvModule(kernel_size, kernel_size=(3,3), strides=(2,2), padding="valid")
        self.pool = tf.keras.layers.MaxPooling2D(pool_size=(3, 3), strides=(2,2))
        self.cat = tf.keras.layers.Concatenate()

    def call(self, input_tensor):
        """Concatenate the conv-branch and pool-branch outputs."""
        branches = [self.conv3(input_tensor), self.pool(input_tensor)]
        return self.cat(branches)

    def get_config(self):
        """Return the sub-layers themselves (not constructor arguments)."""
        return dict(conv3=self.conv3, pool=self.pool, cat=self.cat)
class MiniInception(tf.keras.Model):
    """Small Inception-style image classifier.

    The layers run in this order:
      1. stem conv
      2. 2 inception modules, then downsample
      3. 4 inception modules, then downsample
      4. 2 inception modules
      5. 7x7 average pool, flatten, softmax Dense

    The fixed 7x7 pool matches a 32x32 input (as used in the script below): both
    downsample steps use 3x3/stride-2 'valid' windows, so the spatial size goes
    32 -> 15 -> 7.

    Args:
        num_classes: number of output classes of the softmax classifier.
    """

    def __init__(self, num_classes=10):
        super(MiniInception, self).__init__(name='MiniInception')
        # Stem: the first conv module.
        self.conv_block = ConvModule(96, (3,3), (1,1))
        # Stage 1: 2 inception modules and 1 downsample module.
        self.inception_block1 = InceptionModule(32, 32, name='Block1')
        self.inception_block2 = InceptionModule(32, 48, name='Block2')
        self.downsample_block1 = DownsampleModule(80, name='Block1')
        # Stage 2: 4 inception modules and 1 downsample module.
        self.inception_block3 = InceptionModule(112, 48, name='Block3')
        self.inception_block4 = InceptionModule(96, 64, name='Block4')
        self.inception_block5 = InceptionModule(80, 80, name='Block5')
        self.inception_block6 = InceptionModule(48, 96, name='Block6')
        self.downsample_block2 = DownsampleModule(96, name='Block2')
        # Stage 3: 2 inception modules.
        self.inception_block7 = InceptionModule(176, 160, name='Block7')
        self.inception_block8 = InceptionModule(176, 160, name='Block8')
        # Head: average pooling, flatten, softmax classifier.
        self.avg_pool = tf.keras.layers.AveragePooling2D((7,7))
        self.flat = tf.keras.layers.Flatten()
        self.classfier = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, input_tensor):
        """Forward pass: stem, then every stage in order, then the classifier."""
        x = self.conv_block(input_tensor)
        stages = (
            self.inception_block1, self.inception_block2, self.downsample_block1,
            self.inception_block3, self.inception_block4,
            self.inception_block5, self.inception_block6, self.downsample_block2,
            self.inception_block7, self.inception_block8,
            self.avg_pool, self.flat,
        )
        for stage in stages:
            x = stage(x)
        return self.classfier(x)

    def build_graph(self, raw_shape):
        """Wrap this subclassed model in a functional ``tf.keras.Model``.

        The Input has shape ``raw_shape``, without the batch axis.
        """
        inputs = tf.keras.layers.Input(shape=raw_shape)
        outputs = self.call(inputs)
        return tf.keras.Model(inputs=[inputs], outputs=outputs)

    def get_config(self):
        """Return every sub-layer of the model, keyed by attribute name."""
        return dict(
            conv_block=self.conv_block,
            inception_block1=self.inception_block1,
            inception_block2=self.inception_block2,
            downsample_block1=self.downsample_block1,
            inception_block3=self.inception_block3,
            inception_block4=self.inception_block4,
            inception_block5=self.inception_block5,
            inception_block6=self.inception_block6,
            downsample_block2=self.downsample_block2,
            inception_block7=self.inception_block7,
            inception_block8=self.inception_block8,
            avg_pool=self.avg_pool,
            flat=self.flat,
            classfier=self.classfier,
        )
############################################################ | |
# UTILS # | |
############################################################ | |
@tf.function
def write_model_trace(model, X):
    """Call ``model`` on ``X`` inside a ``tf.function`` and return its output.

    Running the model through a traced function lets ``tf.summary.trace_export``
    (used by the script below) capture the resulting graph for TensorBoard.
    """
    predictions = model(X)
    return predictions
if __name__ == "__main__":
    if 1:
        # json is needed for the .to_json() export below. Import it here so this
        # script does not depend on a module-level `json` import.
        import json

        print(' --------------------- [MiniInception] --------------------- ')

        # Using the summary() function
        raw_shape = (32, 32, 3)
        model = MiniInception()  # tf.keras.Model
        model = model.build_graph(raw_shape)  # <class 'tensorflow.python.keras.engine.functional.Functional'>
        model.summary(line_length=150)

        # Using the save() function
        _ = model(tf.ones((1,*raw_shape)))
        model.save('MiniInception')  # Keras ModelSave format; unable to properly visualize .pb file in www.netron.app
        model.save('MiniInception.h5')  # Loads in www.netron.app for viz purposes, shows the high level blocks

        # Using the .to_json() and .get_config() function
        with open('MiniInception.json', 'w') as fp:
            json.dump(json.loads(model.to_json()), fp, indent=4)
        # model.get_config()

        # Using onnxmltools (does not support Conv3D, MaxPool3D, Conv3DTranspose)
        import onnx
        import onnxmltools
        model_onnx = onnxmltools.convert_keras(model, model.name)  # this is a ModelProto
        model_onnx = onnx.shape_inference.infer_shapes(model_onnx)
        onnxmltools.utils.save_model(model_onnx, 'MiniInception_onnxmltools.onnx')

        # Using keras2onnx
        import keras2onnx
        model_onnx = keras2onnx.convert_keras(model, model.name)
        model_onnx = onnx.shape_inference.infer_shapes(model_onnx)
        keras2onnx.save_model(model_onnx, 'MiniInception_keras2onxx.onnx')

        # Using tf2onnx
        import tf2onnx
        # The explicit TensorSpec ensures the batch size shows up properly in tf2onnx.
        model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model, [tf.TensorSpec((1,*raw_shape), tf.float32, name='ModelInput')])
        model_onnx = onnx.shape_inference.infer_shapes(model_proto)
        tf2onnx.utils.save_protobuf('MiniInception_tf2onxx.onnx', model_onnx)

        # Using tensorboard: trace the first call of the tf.function, then export the graph.
        tf.summary.trace_on(graph=True, profiler=False)
        _ = write_model_trace(model, tf.ones(shape=(1,*raw_shape), dtype=tf.float32))
        writer = tf.summary.create_file_writer('MiniInception_TBoard')
        try:
            with writer.as_default():
                tf.summary.trace_export(name=model.name, step=0, profiler_outdir=None)
                writer.flush()
        finally:
            # Release the event-file handle even if the export fails.
            writer.close()
        print(' - Run command --> tensorboard --logdir=MiniInception_TBoard --port=6100')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment