import keras_core as keras
from keras_core import layers

def ResBlockIdentity(filter, stride):
    """Return a callable applying a residual block with an unprojected shortcut.

    The main path is Conv 3x3 -> BatchNorm -> ReLU -> Conv 3x3 -> BatchNorm.
    The block input is then added back as-is and a final ReLU is applied.

    Args:
        filter: Number of filters used by both convolutions.
        stride: Stride passed to both convolutions.

    Returns:
        A function that takes an input tensor and returns the block output.

    NOTE(review): the shortcut is added without any projection, so this
    presumably only works when both convolutions preserve the input's shape
    (e.g. stride 1 and a matching channel count) -- confirm against callers.
    """
    # Both convolutions share the exact same configuration.
    conv_kwargs = dict(kernel_size=3, strides=stride, padding="same")

    def call(inputs):
        shortcut = inputs

        # First convolution stage, with non-linearity.
        out = layers.Conv2D(filter, **conv_kwargs)(inputs)
        out = layers.BatchNormalization()(out)
        out = layers.Activation("relu")(out)

        # Second convolution stage; the ReLU is deferred until after the add.
        out = layers.Conv2D(filter, **conv_kwargs)(out)
        out = layers.BatchNormalization()(out)

        merged = layers.Add()([out, shortcut])
        return layers.Activation("relu")(merged)

    return call

def ResBlockConv(filter, stride):
    """Return a callable applying a residual block with a projected shortcut.

    Main path: Conv 3x3 (given stride) -> BatchNorm -> ReLU ->
    Conv 3x3 (stride 1) -> BatchNorm.
    Shortcut path: Conv 1x1 (given stride) -> BatchNorm.
    The two paths are summed and passed through a final ReLU.

    Args:
        filter: Number of filters used by every convolution in the block.
        stride: Stride of the first main-path convolution and of the
            shortcut's 1x1 convolution.

    Returns:
        A function that takes an input tensor and returns the block output.
    """

    def call(inputs):
        # Main path: only the first convolution is strided.
        main = layers.Conv2D(
            filter, kernel_size=3, strides=stride, padding="same"
        )(inputs)
        main = layers.BatchNormalization()(main)
        main = layers.Activation("relu")(main)

        main = layers.Conv2D(
            filter, kernel_size=3, strides=1, padding="same"
        )(main)
        main = layers.BatchNormalization()(main)

        # Shortcut path: 1x1 convolution using the same stride as the first
        # main-path convolution, followed by its own batch normalization.
        shortcut = layers.Conv2D(
            filter, kernel_size=1, strides=stride, padding="same"
        )(inputs)
        shortcut = layers.BatchNormalization()(shortcut)

        merged = layers.Add()([main, shortcut])
        return layers.Activation("relu")(merged)

    return call


def ResNet18Backbone(input_shape):
    """Build the convolutional backbone as a Keras functional model.

    The network is a 3x3 convolution stem (64 filters, BatchNorm, ReLU)
    followed by four stages with 64, 128, 256 and 512 filters. Each stage is
    one ``ResBlockConv`` followed by one ``ResBlockIdentity``. The first stage
    uses stride 1 in its ``ResBlockConv``; the remaining stages use stride 2.

    Args:
        input_shape: Shape passed to ``keras.Input`` (without the batch axis).

    Returns:
        A ``keras.Model`` named ``'Resnet18'`` whose output is the feature map
        produced by the last residual block (no pooling or classifier head).
    """
    input_im = keras.Input(shape=input_shape)

    # Stem.
    x = layers.Conv2D(64, (3, 3), padding="same")(input_im)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    # Residual stages: every stage except the first is strided.
    for stage, filters in enumerate([64, 128, 256, 512]):
        x = ResBlockConv(filter=filters, stride=2 if stage else 1)(x)
        x = ResBlockIdentity(filter=filters, stride=1)(x)

    return keras.Model(inputs=input_im, outputs=x, name="Resnet18")


def get_model(input_shape=(32, 32, 3), num_classes=10):
    """Assemble the full classifier around ``ResNet18Backbone``.

    The model is a ``keras.Sequential`` stack of: random-flip augmentation,
    the backbone, global max pooling, and a softmax ``Dense`` head.

    Args:
        input_shape: Input shape forwarded to ``ResNet18Backbone``.
            Defaults to ``(32, 32, 3)``.
        num_classes: Number of units in the softmax output layer.
            Defaults to ``10``.

    Returns:
        The assembled ``keras.Sequential`` model.
    """
    model = keras.Sequential(
        [
            keras.layers.RandomFlip(),
            ResNet18Backbone(input_shape=input_shape),
            keras.layers.GlobalMaxPooling2D(),
            # The head's dtype is pinned to float32 -- presumably so the
            # softmax output stays float32 under a mixed-precision policy;
            # confirm.
            keras.layers.Dense(
                num_classes, activation="softmax", dtype="float32"
            ),
        ]
    )
    # Return the model so callers do not receive None.
    return model
    
# Module-level instance: get_model() is called once at import time and its
# return value is bound to `resnet18`.
resnet18 = get_model()