Skip to content

Instantly share code, notes, and snippets.

@armamut
Created September 25, 2020 14:19
Show Gist options
  • Select an option

  • Save armamut/84b406412d1be598a686cd2866027a56 to your computer and use it in GitHub Desktop.

Select an option

Save armamut/84b406412d1be598a686cd2866027a56 to your computer and use it in GitHub Desktop.
Focal loss for LightGBM, using a functional (chain-rule) approach for the gradient and Hessian.
def focal_loss_lgb(y_true, y_pred, alpha=0.25, gamma=1.0):
    """Focal-loss objective for LightGBM binary classification.

    Returns the gradient and Hessian of the focal loss
    (https://arxiv.org/pdf/1708.02002.pdf) with respect to the raw
    margin ``y_pred``, as required by a LightGBM custom objective.

    Parameters
    ----------
    y_true : np.ndarray
        Binary labels (0 or 1).
    y_pred : np.ndarray
        Raw (pre-sigmoid) scores.
    alpha : float
        Class-balance weight applied to the positive class; negatives
        get ``1 - alpha``.
    gamma : float
        Focusing parameter of the focal loss.

    Returns
    -------
    (grad, hess) : tuple of np.ndarray
        First and second derivatives of the loss w.r.t. ``y_pred``.
    """
    # Class-balance weight: alpha for positives, (1 - alpha) for negatives.
    a = y_true * alpha + (1 - y_true) * (1 - alpha)
    g = gamma
    # Sign of dp/dx: p (prob. of the true class) increases with the raw
    # score for positives and decreases for negatives.
    s = y_true * 2 - 1
    # p = probability assigned to the *true* class.
    sig = 1 / (1 + np.exp(-y_pred))
    p = y_true * sig + (1 - y_true) * (1 - sig)
    # Derivatives of p w.r.t. the raw score x.
    # Note sig*(1-sig) == p*(1-p) for either label, hence:
    u = p * (1 - p)
    dpdx = s * u                     # dp/dx = s * p * (1 - p)
    d2pdx2 = u * (1 - 2 * p)         # d/dx(s*u) = s*(1-2p)*dpdx = u*(1-2p)
    logp = np.log(p)
    # First derivative of the focal loss w.r.t. p:
    # https://www.wolframalpha.com/input/?i=d%2Fdp+-a*log%28p%29*%281-p%29**g
    dfdp = a * (1 - p) ** (g - 1) * (g * p * logp + p - 1) / p
    # Second derivative of the focal loss w.r.t. p:
    # https://www.wolframalpha.com/input/?i=d2%2Fdp2+-a*log%28p%29*%281-p%29**g
    # BUG FIX: the original folded the label sign ``s`` into dfdp/d2fdp2
    # and multiplied the final Hessian by ``s``, but inside d2fdp2 only
    # one of the two terms carried ``s``.  As a result the
    # (g-1)*g*p^2*log(p) term had the wrong sign for negative samples
    # whenever gamma was not 0 or 1 (the default gamma=1.0 hid the bug
    # because (g-1)*g == 0 there).  Here dfdp/d2fdp2 are the true,
    # unsigned derivatives and the sign lives in dpdx where it belongs.
    d2fdp2 = -a * (1 - p) ** (g - 2) * (
        (g - 1) * g * p * p * logp + (p - 1) * ((2 * g - 1) * p + 1)
    ) / (p * p)
    # Plain chain rule; (dpdx)**2 == u**2 since s**2 == 1.
    grad = dfdp * dpdx
    hess = d2fdp2 * u * u + dfdp * d2pdx2
    return grad, hess
def focal_loss_lgb_eval(y_true, y_pred, alpha=0.25, gamma=1.0):
    """Focal-loss evaluation metric for LightGBM.

    Computes the mean focal loss (https://arxiv.org/pdf/1708.02002.pdf)
    over all samples and returns it in the
    ``(name, value, is_higher_better)`` triple that LightGBM expects
    from a custom eval function.

    Parameters
    ----------
    y_true : np.ndarray
        Binary labels (0 or 1).
    y_pred : np.ndarray
        Raw (pre-sigmoid) scores.
    alpha : float
        Class-balance weight for the positive class.
    gamma : float
        Focusing parameter.
    """
    # Probability of the ground-truth class.
    prob = 1 / (1 + np.exp(-y_pred))
    prob_true = y_true * prob + (1 - y_true) * (1 - prob)
    # Per-sample class-balance weight: alpha for positives, 1-alpha otherwise.
    weight = y_true * alpha + (1 - y_true) * (1 - alpha)
    per_sample = -weight * np.log(prob_true) * (1 - prob_true) ** gamma
    # Lower is better, hence the final False.
    return 'focal_loss', np.mean(per_sample), False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment