Created August 29, 2018 20:01
-
-
Save brookisme/03e297e056a445058ba5d303a95af79f to your computer and use it in GitHub Desktop.
PyTorch Weighted Categorical Crossentropy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math

import torch
import torch.nn as nn

import pytorch_nns.helpers as h
import pytorch_nns.functional as f
# | |
# HELPERS | |
# | |
def category_weights(
        count_dict,
        total=None,
        use_log=True,
        multiplier=0.15,
        min_weight=1.0,
        max_weight=25.0):
    """ category_weights

    Compute a loss weight for each category from its count. The raw weight is
    ``multiplier * total / count`` (optionally passed through ``log``), then
    clipped to ``[min_weight, max_weight]``, so rarer categories get larger
    weights.

    Args:
        * count_dict <dict>: dictionary of category counts
        * total <int|None>: total count (if None/0, the sum of the counts is used)
        * use_log <bool [True]>: take log of distribution weight
        * multiplier <float [0.15]>: multiplier for log argument
        * min_weight <float [1.0]>: min weight value
        * max_weight <float [25.0]>: max weight value

    Returns:
        * <dict> mapping each key of count_dict to its clipped weight
    """
    if not total:
        total = sum(count_dict.values())
    weights = {}
    for key, v in count_dict.items():
        if not v:
            # Zero count: substitute a tiny value so the ratio becomes very
            # large (and is clipped to max_weight when EPS is small enough).
            # NOTE(review): EPS is not defined in this file; it must be
            # provided at module level.
            v = EPS
        weight = multiplier * total / float(v)
        if use_log:
            # log is undefined for weight <= 0 (e.g. every count is zero, or
            # multiplier == 0). Treat that as -inf so it clips to min_weight,
            # which matches the limit of log(w) as w -> 0+.
            weight = math.log(weight) if weight > 0 else float("-inf")
        weights[key] = min(max_weight, max(weight, min_weight))
    return weights
# | |
# FUNCTIONAL LOSSES | |
# | |
def weighted_categorical_crossentropy(inpt, targ, weights):
    """ weighted_categorical_crossentropy

    Args:
        * inpt <tensor>: network prediction. Renormalized to sum to 1 along
          dim 1, so it is expected to hold non-negative scores.
          NOTE(review): presumably (batch, n_categories, ...) probabilities
          such as a softmax output; confirm against callers.
        * targ <tensor>: network target, broadcast-compatible with inpt
          (presumably one-hot along dim 1; verify).
        * weights <tensor|nparray|list>: category weights, one per entry of
          dim 1

    Returns:
        * mean reduction of weighted categorical crossentropy, multiplied by
          targ.size(1)
    """
    # Put the weights on the prediction's device. Weights built on the CPU
    # would otherwise fail against GPU inputs.
    weights = h.to_tensor(weights).float().to(inpt.device)
    # Renormalize along dim 1, then keep log() away from 0 and 1.
    # NOTE(review): EPS is not defined in this file; it must be provided at
    # module level.
    inpt = inpt / (inpt.sum(1, True) + EPS)
    inpt = torch.clamp(inpt, EPS, 1. - EPS)
    losses = (targ * torch.log(inpt)).float()
    # Swap dim 1 to the last position so a 1-D weights vector broadcasts
    # against the category dimension.
    weighted_losses_transpose = weights * losses.transpose(1, -1)
    # mean() over all elements times size(1) equals the sum over categories,
    # averaged over the remaining positions.
    return -weighted_losses_transpose.mean() * targ.size(1)
# | |
# CRITERION | |
# | |
class WeightedCategoricalCrossentropy(nn.Module):
    """ WeightedCategoricalCrossentropy

    Criterion module that applies f.weighted_categorical_crossentropy
    (mean reduction of weighted categorical crossentropy) with fixed weights.

    Args:
        * weights <tensor|nparray|list>: category weights
        * device <str|None>: device-name. if given, send weights to specified device
    """
    def __init__(self, weights, device=None):
        super(WeightedCategoricalCrossentropy, self).__init__()
        weights = h.to_tensor(weights)
        if device:
            weights = weights.to(device)
        # Register the weights as a buffer so module.to()/.cuda() moves them
        # together with the module. They remain accessible as self.weights and
        # are now included in state_dict().
        self.register_buffer('weights', weights)

    def forward(self, inpt, targ):
        """Return the weighted categorical crossentropy of inpt against targ."""
        return f.weighted_categorical_crossentropy(inpt, targ, self.weights)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.