Skip to content

Instantly share code, notes, and snippets.

@piercus
Created January 26, 2022 09:54
Show Gist options
  • Save piercus/6322500066f4c1f99a0a57a10c3d8b64 to your computer and use it in GitHub Desktop.
Save piercus/6322500066f4c1f99a0a57a10c3d8b64 to your computer and use it in GitHub Desktop.
import warnings
import torch
import torch.nn.functional as F
import numpy as np
from mmcls.models.losses import Accuracy
from mmcls.models.builder import HEADS, build_loss, build_head
from mmcls.models.utils import is_tracing
from mmcls.models.heads.cls_head import ClsHead
@HEADS.register_module()
class MultiTaskHead(ClsHead):
    """Classification head that runs several sub-heads on the same features.

    Each sub-head is built from its own config with ``build_head`` and
    registered as a child module under the matching entry of ``names``.

    Args:
        heads (tuple | list[dict]): Config of each sub-head.
        names (tuple | list[str]): Name of each sub-head, used as the child
            module name and in the loss keys ``loss_<name>``. Must be unique
            and have the same length as ``heads``.
        base_head (dict, optional): Default config shared by every sub-head.
            Keys present in a sub-head's own config take precedence over it.
            Defaults to None.
        **kwargs: Forwarded to ``ClsHead.__init__``.

    Raises:
        TypeError: If ``heads`` is not a tuple or list.
        ValueError: If ``names`` and ``heads`` differ in length, or if
            ``names`` contains duplicates.
    """

    def __init__(self,
                 heads,
                 names,
                 base_head=None,
                 **kwargs):
        super(MultiTaskHead, self).__init__(**kwargs)
        if not isinstance(heads, (tuple, list)):
            raise TypeError(
                f'heads should be a tuple or list, got {type(heads).__name__}')
        if len(names) != len(heads):
            raise ValueError(
                f'len(names)={len(names)} should equal len(heads)={len(heads)}')
        # add_module() silently replaces an existing child of the same name,
        # which would leave the earlier sub-head's parameters unregistered
        # and make its loss key collide in forward_train().
        if len(set(names)) != len(names):
            raise ValueError(f'names should be unique, got {names}')
        self.names = names
        # A plain list keeps the heads in order for iteration. The modules
        # themselves are registered only once, via add_module() below, under
        # their given names.
        self.heads = []
        for name, head in zip(names, heads):
            # Merge into a new dict so the caller's config is never mutated.
            # Sub-head keys override the shared defaults.
            head_cfg = head if base_head is None else {**base_head, **head}
            module_head = build_head(head_cfg)
            self.heads.append(module_head)
            self.add_module(name, module_head)

    def forward_train(self, x, gt_label, **kwargs):
        """Compute the training loss of every sub-head.

        Args:
            x: Input features, forwarded unchanged to each sub-head.
            gt_label (Tensor): Labels for all tasks. Column ``i``
                (``gt_label[:, i]``) is given to the i-th sub-head, so this
                is presumably shaped (N, num_heads); confirm against the
                dataset pipeline.
            **kwargs: Forwarded to each sub-head's ``forward_train``.

        Returns:
            dict: Maps ``loss_<name>`` to the ``'loss'`` entry returned by
            the corresponding sub-head. Other entries are not propagated.
        """
        losses = dict()
        for index, (name, head) in enumerate(zip(self.names, self.heads)):
            head_loss = head.forward_train(x, gt_label[:, index], **kwargs)
            losses[f'loss_{name}'] = head_loss['loss']
        return losses

    def simple_test(self, x, **kwargs):
        """Inference without augmentation.

        Args:
            x: Input features, forwarded unchanged to each sub-head.
            **kwargs: Forwarded to each sub-head's ``simple_test``.

        Returns:
            np.ndarray | Tensor: The predictions of all sub-heads joined
            along axis 1, in ``names`` order.

            - If every sub-head returns a ``torch.Tensor`` (presumably when
              post-processing to numpy is skipped, e.g. under tracing), the
              tensors are joined with ``torch.cat``. This keeps them on their
              device and in the traced graph.
            - Otherwise each sub-head's per-sample outputs are stacked with
              ``np.stack`` and the stacks are concatenated with
              ``np.concatenate``.
        """
        results = [head.simple_test(x, **kwargs) for head in self.heads]
        if results and all(isinstance(r, torch.Tensor) for r in results):
            # np.stack would fail on CUDA / requires-grad tensors and would
            # break tracing, so stay in torch.
            return torch.cat(results, dim=1)
        stacked = [np.stack(r, axis=0) for r in results]
        return np.concatenate(stacked, axis=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment