Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save brookisme/03e297e056a445058ba5d303a95af79f to your computer and use it in GitHub Desktop.
Save brookisme/03e297e056a445058ba5d303a95af79f to your computer and use it in GitHub Desktop.
PyTorch Weighted Categorical Crossentropy
import math

import torch
import torch.nn as nn

import pytorch_nns.helpers as h
import pytorch_nns.functional as f
#
# HELPERS
#
def category_weights(
        count_dict,
        total=None,
        use_log=True,
        multiplier=0.15,
        min_weight=1.0,
        max_weight=25.0,
        eps=1e-7):
    """ category_weights

    Compute a clamped inverse-frequency weight for each category.

    The raw weight of a category is ``multiplier * total / count``. It is
    optionally passed through ``math.log`` and then clamped to
    ``[min_weight, max_weight]``.

    Args:
        * count_dict <dict>: dictionary of category counts
        * total <int|None>: total count (if falsy, the sum of count_dict values)
        * use_log <bool [True]>: take log of distribution weight
        * multiplier <float [0.15]>: multiplier for log argument
        * min_weight <float [1.0]>: min weight value
        * max_weight <float [25.0]>: max weight value
        * eps <float [1e-7]>: stand-in for zero counts (avoids division by zero)
    Returns:
        * <dict> mapping each key of count_dict to its clamped weight
    """
    if not total:
        total = sum(count_dict.values())
    weights = {}
    for key, count in count_dict.items():
        # A zero count would divide by zero; treat it as a tiny count instead.
        count = count or eps
        weight = multiplier * total / float(count)
        if use_log:
            # log(w) -> -inf as w -> 0, which the clamp below maps to min_weight;
            # guarding here avoids math.log(0) raising when total is 0.
            weight = math.log(weight) if weight > 0 else min_weight
        weights[key] = min(max_weight, max(weight, min_weight))
    return weights
#
# FUNCTIONAL LOSSES
#
def weighted_categorical_crossentropy(inpt, targ, weights, eps=1e-7):
    """ weighted_categorical_crossentropy

    Args:
        * inpt <tensor>: network prediction; non-negative class scores on dim 1
          (renormalized along dim 1 here, so they need not already sum to 1)
        * targ <tensor>: network target (presumably one-hot, same shape as
          inpt -- confirm against callers)
        * weights <tensor|nparray|list>: category weights, one per class on dim 1
        * eps <float [1e-7]>: clipping epsilon keeping log() finite
    Returns:
        * mean reduction of weighted categorical crossentropy
    """
    # Build the weights on the prediction's device so that CPU lists/arrays
    # (or tensors on another device) work with GPU inputs.
    weights = torch.as_tensor(weights, device=inpt.device).float()
    # Renormalize class scores along dim 1, then clip so log() stays finite.
    inpt = inpt / (inpt.sum(1, True) + eps)
    inpt = torch.clamp(inpt, eps, 1. - eps)
    losses = (targ * torch.log(inpt)).float()
    # Move the class dim last so the per-class weights broadcast across it.
    weighted_losses_transpose = weights * losses.transpose(1, -1)
    # Mean over all elements times the class count equals the per-position sum
    # over classes, averaged over the remaining positions.
    return -weighted_losses_transpose.mean() * targ.size(1)
#
# CRITERION
#
class WeightedCategoricalCrossentropy(nn.Module):
    """ WeightedCategoricalCrossentropy

    Criterion wrapping f.weighted_categorical_crossentropy
    (mean reduction of weighted categorical crossentropy).

    Args:
        * weights<tensor|nparray|list>: category weights
        * device<str|None>: device-name. if exists, send weights to specified device
    """
    def __init__(self, weights, device=None):
        super(WeightedCategoricalCrossentropy, self).__init__()
        # assumes h.to_tensor returns a torch.Tensor (the original called
        # .to(device) on it) -- register_buffer requires a Tensor
        weights = h.to_tensor(weights)
        if device:
            weights = weights.to(device)
        # A buffer (unlike a plain attribute) follows module.to()/.cuda().
        # persistent=False keeps it out of state_dict, so checkpoint keys of
        # models holding this criterion are unchanged.
        self.register_buffer("weights", weights, persistent=False)

    def forward(self, inpt, targ):
        return f.weighted_categorical_crossentropy(inpt, targ, self.weights)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment