Skip to content

Instantly share code, notes, and snippets.

@zhreshold
Created January 24, 2018 05:01
Show Gist options
  • Save zhreshold/26f98aada1e277661a6aa6cefde2f6a9 to your computer and use it in GitHub Desktop.
Save zhreshold/26f98aada1e277661a6aa6cefde2f6a9 to your computer and use it in GitHub Desktop.
Weighted logistic regression output via mxnet custom operator
import mxnet as mx
class WeightedLogisticRegressionOutput(mx.operator.CustomOp):
    """Logistic regression output whose gradient is down-weighted for
    already-confident correct predictions.

    Labels are expected to take three values: 0, 0.5, 1.  When a
    label-0 sample is predicted below ``lower`` (or a label-1 sample
    above ``upper``), the gradient is scaled by ``|grad| ** (beta - 1)``,
    which shrinks it for ``beta < 1``.  As ``beta -> 1`` this reduces to
    ordinary logistic regression.

    Parameters
    ----------
    beta : float
        Gradient-damping exponent; 1.0 gives plain logistic regression.
    lower : float
        Prediction threshold below which a label-0 sample counts as
        confidently correct.
    upper : float
        Prediction threshold above which a label-1 sample counts as
        confidently correct.
    """

    def __init__(self, beta=0.5, lower=0.3, upper=0.7):
        super(WeightedLogisticRegressionOutput, self).__init__()
        self._lower = lower
        self._upper = upper
        self._beta = beta

    def forward(self, is_train, req, in_data, out_data, aux):
        # Logistic regression forward pass: apply sigmoid to the input.
        self.assign(out_data[0], req[0], mx.nd.sigmoid(in_data[0]))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        out = out_data[0]
        label = in_data[1].reshape_like(out)
        # Plain logistic-regression gradient: sigmoid(x) - label.
        # BUG FIX: the original rebound the name ``in_grad`` to this
        # NDArray, so the subsequent ``in_grad[0]`` indexed the first
        # row of the gradient tensor instead of the gradient-output
        # list that MXNet expects us to write into.
        grad = out - label
        # Labels come in three kinds: 0, 0.5, 1.
        #   label == 0 and prediction < lower -> already correct, damp grad
        #   label == 1 and prediction > upper -> already correct, damp grad
        condition = (label < 0.5) * (out < self._lower) \
            + (label > 0.5) * (out > self._upper)
        # If beta -> 1, the weight becomes 1 everywhere and this is
        # normal logistic regression.
        weight = mx.nd.where(condition > 0,
                             mx.nd.abs(grad) ** (self._beta - 1),
                             mx.nd.ones_like(grad))
        self.assign(in_grad[0], req[0], grad * weight)
@mx.operator.register("weighted_logistic_regression_output")
class WeightedLogisticRegressionOutputProp(mx.operator.CustomOpProp):
    """Property class for :class:`WeightedLogisticRegressionOutput`.

    Registered under the name ``"weighted_logistic_regression_output"``.
    The loss gradient is produced inside ``backward`` itself, so no
    gradient from the layer above is required (``need_top_grad=False``).

    BUG FIX: the original subclassed ``mx.Operator.CustomOpProp``
    (capital ``O``), which raises ``AttributeError`` at class-definition
    time; the correct module is ``mx.operator``.
    """

    def __init__(self, beta=0.5, lower=0.3, upper=0.7):
        super(WeightedLogisticRegressionOutputProp, self).__init__(need_top_grad=False)
        # MXNet passes custom-op kwargs from the symbol API as strings,
        # so coerce them to float before they reach the arithmetic in
        # backward().
        self._beta = float(beta)
        self._lower = float(lower)
        self._upper = float(upper)

    def list_arguments(self):
        return ['data', 'label']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        # Data and output share a shape; the label carries one value per
        # batch element.  backward() reshapes the label to match the
        # output, so data is presumably (batch,) or (batch, 1) — TODO
        # confirm against callers.
        dshape = in_shape[0]
        lshape = (in_shape[0][0],)
        oshape = in_shape[0]
        return [dshape, lshape], [oshape], []

    def create_operator(self, ctx, shapes, dtypes):
        return WeightedLogisticRegressionOutput(beta=self._beta,
                                                lower=self._lower,
                                                upper=self._upper)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment