"""ResNet-18 style image classifier built with Keras Core."""

import keras_core as keras
from keras_core import layers


def ResBlockIdentity(filter, stride):
    """Return a callable that applies a residual block with an identity shortcut.

    Main path: Conv3x3(stride) -> BN -> ReLU -> Conv3x3(stride 1) -> BN,
    then the unmodified input is added and ReLU is applied.

    Because the shortcut is not projected, the main path must preserve the
    input's spatial size and channel count for the ``Add`` to be valid. In
    practice this means ``stride == 1`` and ``filter`` equal to the number of
    input channels.

    Args:
        filter: Number of output channels of each 3x3 convolution.
        stride: Stride of the first 3x3 convolution (expected to be 1).

    Returns:
        A function mapping an input tensor to the block's output tensor.
    """

    def apply(inputs):
        skip = inputs
        x = layers.Conv2D(filter, kernel_size=3, strides=stride, padding="same")(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        # Only the first conv may stride; striding again here would shrink the
        # main path a second time and make it impossible to match `skip`.
        x = layers.Conv2D(filter, kernel_size=3, strides=1, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Add()([x, skip])
        x = layers.Activation("relu")(x)
        return x

    return apply


def ResBlockConv(filter, stride):
    """Return a callable that applies a residual block with a projection shortcut.

    Main path: Conv3x3(stride) -> BN -> ReLU -> Conv3x3(stride 1) -> BN.
    Shortcut: Conv1x1(stride) -> BN, so it matches the main path's spatial
    size and channel count. The two are added and ReLU is applied.

    Args:
        filter: Number of output channels of the block.
        stride: Stride applied to the first 3x3 conv and to the 1x1 shortcut.

    Returns:
        A function mapping an input tensor to the block's output tensor.
    """

    def apply(inputs):
        skip = inputs
        x = layers.Conv2D(filter, kernel_size=3, strides=stride, padding="same")(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        x = layers.Conv2D(filter, kernel_size=3, strides=1, padding="same")(x)
        x = layers.BatchNormalization()(x)
        # Project the shortcut so it can be added to the (possibly
        # downsampled, re-channelled) main path.
        skip = layers.Conv2D(filter, kernel_size=1, strides=stride, padding="same")(skip)
        skip = layers.BatchNormalization()(skip)
        x = layers.Add()([x, skip])
        x = layers.Activation("relu")(x)
        return x

    return apply


def ResNet18Backbone(input_shape):
    """Build the convolutional ResNet-18 style feature extractor.

    A 3x3 stem (64 channels, stride 1) is followed by four stages with 64,
    128, 256 and 512 channels. Each stage is one projection block and one
    identity block. Every stage except the first downsamples by 2 in its
    projection block.

    Args:
        input_shape: Shape of a single input example, e.g. ``(32, 32, 3)``.

    Returns:
        A functional ``keras.Model`` named ``"Resnet18"`` that outputs the
        final feature map (no pooling or classifier head).
    """
    input_im = keras.Input(shape=input_shape)
    x = layers.Conv2D(64, (3, 3), padding="same")(input_im)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    for i, kernel in enumerate([64, 128, 256, 512]):
        # The first stage keeps the stem's resolution; later stages halve it.
        x = ResBlockConv(filter=kernel, stride=2 if i else 1)(x)
        x = ResBlockIdentity(filter=kernel, stride=1)(x)

    model = keras.Model(inputs=input_im, outputs=x, name="Resnet18")
    return model


def get_model(*, input_shape=(32, 32, 3), num_classes=10):
    """Build the full classifier: augmentation, backbone, pooling, softmax head.

    Args:
        input_shape: Shape of a single input example passed to the backbone.
        num_classes: Number of output classes of the softmax layer.

    Returns:
        A ``keras.Sequential`` model.
    """
    model = keras.Sequential(
        [
            # NOTE(review): RandomFlip() is used with its default mode; for
            # natural images a horizontal-only flip is more common. Confirm intent.
            layers.RandomFlip(),
            ResNet18Backbone(input_shape=input_shape),
            layers.GlobalMaxPooling2D(),
            # Keep the head in float32 so the softmax stays numerically stable
            # even if a lower-precision policy is used elsewhere.
            layers.Dense(num_classes, activation="softmax", dtype="float32"),
        ]
    )
    return model


resnet18 = get_model()