Skip to content

Instantly share code, notes, and snippets.

@atisheksingh
Created November 22, 2024 09:40
Show Gist options
  • Save atisheksingh/9890fcf588e8c61e3f4004453e2aba4c to your computer and use it in GitHub Desktop.
Save atisheksingh/9890fcf588e8c61e3f4004453e2aba4c to your computer and use it in GitHub Desktop.
Untitled3.ipynb
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
@atisheksingh
Copy link
Author

 def build_model(self, learning_rate=0.001):
        """Build and compile a binary classifier on a frozen MobileNetV2 backbone.

        The backbone's final feature map is summarised with both global
        average pooling and global max pooling. The two vectors are
        concatenated and fed through
        Dropout(0.5) -> Dense(512, relu, L2=0.01) -> Dropout(0.3)
        -> Dense(1, sigmoid).

        Args:
            learning_rate: Learning rate for the Adam optimizer.

        Side effects:
            * Sets the global Keras dtype policy to ``'mixed_float16'``. This
              happens before any layer is created, because Keras layers
              capture the global policy at construction time. Setting it after
              construction would leave this model in float32 and silently
              change the dtype of models built on later calls. The policy
              stays set after this method returns.
            * Assigns the compiled ``tf.keras.Model`` to ``self.model``.

        Reads ``self.base_model_weights`` (passed as MobileNetV2's ``weights``
        argument) and ``self.input_shape``.
        """
        # Must precede every layer construction below, including the backbone.
        # NOTE(review): mixed precision only speeds up training on hardware
        # with fast float16 support; on CPU it may be slower. Confirm the
        # target device.
        tf.keras.mixed_precision.set_global_policy('mixed_float16')

        base_model = tf.keras.applications.MobileNetV2(
            weights=self.base_model_weights,
            include_top=False,
            input_shape=self.input_shape
        )
        # Freeze the backbone so only the new classification head is trained.
        base_model.trainable = False

        # Two pooled views of the same feature map, concatenated.
        features = base_model.output
        branch1 = tf.keras.layers.GlobalAveragePooling2D()(features)
        branch2 = tf.keras.layers.GlobalMaxPooling2D()(features)
        merged = tf.keras.layers.Concatenate()([branch1, branch2])

        # Regularised dense head.
        x = tf.keras.layers.Dropout(0.5)(merged)
        x = tf.keras.layers.Dense(
            512,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(0.01)
        )(x)
        x = tf.keras.layers.Dropout(0.3)(x)

        # Keep the output layer in float32 under the mixed policy, so the
        # sigmoid and the binary-crossentropy loss are not computed in float16.
        output = tf.keras.layers.Dense(
            1,
            activation='sigmoid',
            dtype='float32'
        )(x)

        self.model = tf.keras.Model(inputs=base_model.input, outputs=output)

        # Per the TF mixed-precision guide, Model.compile wraps the optimizer
        # in a LossScaleOptimizer when the model uses 'mixed_float16'.
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

        self.model.compile(
            optimizer=optimizer,
            loss='binary_crossentropy',
            metrics=[
                'accuracy',
                tf.keras.metrics.AUC(),
                tf.keras.metrics.Precision(),
                tf.keras.metrics.Recall(),
            ]
        )

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment