import numpy as np

from mmcls.models.builder import HEADS, build_head
from mmcls.models.heads.cls_head import ClsHead
@HEADS.register_module()
class MultiTaskHead(ClsHead):
    """Multi-task head.

    Args:
        heads (tuple): Config dicts of the sub-heads to use.
        names (tuple): Name of each sub-head, in the same order as ``heads``.
        base_head (dict): Default config merged into every sub-head config.
            Defaults to None.
    """
    def __init__(self,
                 heads,
                 names,
                 base_head=None,
                 **kwargs):
        super(MultiTaskHead, self).__init__(**kwargs)
        assert isinstance(heads, tuple)
        if len(names) != len(heads):
            raise ValueError(
                f'len(names)={len(names)} should equal len(heads)={len(heads)}')
        self.names = names
        self.heads = []
        for index, head in enumerate(heads):
            if base_head is not None:
                # base_head holds shared defaults; per-head settings
                # take precedence over it.
                merged = dict(base_head)
                merged.update(head)
                head = merged
            module_head = build_head(head)
            self.heads.append(module_head)
            # Register each sub-head under its name so its parameters
            # are tracked (and trained) with the parent module.
            self.add_module(names[index], module_head)
    def forward_train(self, x, gt_label, **kwargs):
        """Forward each column of ``gt_label`` to the matching sub-head.

        ``gt_label`` is expected to have shape (N, num_heads), with
        column ``index`` holding the labels for ``self.heads[index]``.
        """
        losses = dict()
        for index, head in enumerate(self.heads):
            head_loss = head.forward_train(x, gt_label[:, index], **kwargs)
            # Prefix each sub-loss with its head name so the losses
            # stay distinguishable in the training logs.
            losses[f'loss_{self.names[index]}'] = head_loss['loss']
        return losses
    def simple_test(self, x, **kwargs):
        """Inference without augmentation.

        Args:
            x (tuple[Tensor]): The input features, forwarded to each head.

        Returns:
            np.ndarray: The per-head predictions, stacked per sample and
            concatenated along the class axis, with shape
            (N, sum of each head's number of outputs).
        """
        # Each sub-head returns a list of per-sample prediction arrays;
        # stack each list into an (N, num_classes) array first.
        all_heads = tuple(
            np.stack(head.simple_test(x, **kwargs), axis=0)
            for head in self.heads)
        return np.concatenate(all_heads, axis=1)
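

# A minimal usage sketch, not part of the original gist: it drives
# MultiTaskHead directly with mmcls's LinearClsHead and CrossEntropyLoss.
# The head names, class counts, and feature width below are hypothetical
# placeholders, and it assumes an mmcls 0.x-era LinearClsHead whose
# simple_test accepts a plain feature tensor and returns per-sample arrays.
if __name__ == '__main__':
    import torch

    head = MultiTaskHead(
        heads=(
            dict(type='LinearClsHead', num_classes=10),
            dict(type='LinearClsHead', num_classes=3),
        ),
        names=('digit', 'color'),
        # Shared defaults; each per-head config above takes precedence.
        base_head=dict(
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        ),
    )

    feats = torch.randn(4, 512)             # (N, in_channels) backbone features
    gt = torch.randint(0, 3, (4, 2))        # (N, num_heads): one label column per head
    losses = head.forward_train(feats, gt)  # {'loss_digit': ..., 'loss_color': ...}
    preds = head.simple_test(feats)         # ndarray of shape (4, 10 + 3)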