Example of how to load a `timm` architecture into the YOLOX experiment setup. This file looks specifically at `ghostnet_100`, but the approach extends to any other `timm` architecture that supports the `features_only` interface.
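To adapt this to a different backbone, it is worth first checking that the candidate model actually supports `features_only` and seeing which channel dims it exposes. A minimal sketch (the model name `mobilenetv3_large_100` is just an illustrative choice, not part of the gist):

import timm

# Hypothetical alternative backbone; any timm model exposing `features_only` should work.
m = timm.create_model(
    "mobilenetv3_large_100",
    features_only=True,
    out_indices=[2, 3, 4],
    pretrained=False,
)
print(m.feature_info.channels())   # per-stage channel dims, fed to YOLOPAFPN below
print(m.feature_info.reduction())  # strides of the selected feature maps

The full experiment file follows.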
import timm
import torch
import torch.distributed as dist
import torch.nn as nn

from upyog.imports import *
from yolox.exp.yolox_base import Exp as DefaultBaseExp
from yolox.models import YOLOPAFPN, YOLOX, YOLOXHead
from yolox.utils import get_local_rank, wait_for_the_master

__all__ = ["GhostNetAABaseCOCOExp"]
def freeze_layer(m: nn.Module):
    for p in m.parameters():
        p.requires_grad = False
class GhostNetAABaseCOCOExp(DefaultBaseExp):
    def __init__(self):
        super().__init__()
        self.enable_mixup = False
        self.freeze_backbone = False
        self.freeze_fpn = False

    def get_model(self):
        def create_bbone():
            # `features_only` returns intermediate feature maps; out_indices
            # [2, 3, 4] picks the three stages that feed the FPN
            return timm.create_model(
                "ghostnet_100",
                features_only=True,
                out_indices=[2, 3, 4],
                pretrained=True,
            )

        class TIMMWrap(nn.Module):
            # Adapts the timm feature extractor to YOLOPAFPN's expectation of
            # a dict of named feature maps ("dark3", "dark4", "dark5")
            def __init__(self):
                super().__init__()
                self.m = create_bbone()
                self.feature_names = ("dark3", "dark4", "dark5")

            def forward(self, x):
                out = self.m(x)
                out = {k: v for k, v in zip(self.feature_names, out)}
                return out

        def init_yolo(M):
            for m in M.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03

        if getattr(self, "model", None) is None:
            # Read the channel dims of the selected stages from timm's feature info
            in_channels = create_bbone().feature_info.channels()
            fpn = YOLOPAFPN(
                depth=1,
                width=1,
                in_features=("dark3", "dark4", "dark5"),
                in_channels=in_channels,
                depthwise=True,
            )
            # Swap the default CSPDarknet backbone for the timm model
            fpn.backbone = TIMMWrap()
            head = YOLOXHead(
                self.num_classes,
                width=1,
                in_channels=in_channels,
                depthwise=True,
            )
            self.model = YOLOX(fpn, head)

        self.model.apply(init_yolo)
        self.model.head.initialize_biases(1e-2)
        self.load_pretrained_model_()
        self.freeze_model()
        return self.model
    def freeze_model(self):
        # uses the module-level `freeze_layer` helper defined above
        if self.freeze_backbone:
            freeze_layer(self.model.backbone.backbone)
            logger.info("Froze backbone")

        if self.freeze_fpn:
            assert self.freeze_backbone, "must freeze backbone if freezing FPN"
            fpn_modules = [
                "upsample",
                "lateral_conv0",
                "C3_p4",
                "reduce_conv1",
                "C3_p3",
                "bu_conv2",
                "C3_n3",
                "bu_conv1",
                "C3_n4",
            ]
            for fpn_module in fpn_modules:
                freeze_layer(getattr(self.model.backbone, fpn_module))
            logger.info("Froze FPN")
    def load_pretrained_model_(self):
        pass
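A minimal sketch of how this experiment class might be exercised, assuming the YOLOX repo and the imports above are available; the dummy forward pass and the flag setting are illustrative, not part of the gist:

import torch

exp = GhostNetAABaseCOCOExp()
exp.freeze_backbone = True  # optional: keep the pretrained GhostNet weights fixed
model = exp.get_model()

model.eval()
with torch.no_grad():
    preds = model(torch.randn(1, 3, *exp.input_size))  # input_size comes from the base Exp

For actual training, this file would typically be passed to YOLOX's tools/train.py via the -f / --exp_file flag, just like the stock experiment files.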