Created
September 25, 2020 14:19
-
-
Save armamut/84b406412d1be598a686cd2866027a56 to your computer and use it in GitHub Desktop.
Focal loss for LightGBM, using a functional approach for the derivatives
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def focal_loss_lgb(y_true, y_pred, alpha=0.25, gamma=1.0):
    """Return the gradient and Hessian of the binary focal loss w.r.t. raw scores.

    Signature ``(y_true, y_pred) -> (grad, hess)`` as used for a custom
    LightGBM objective. The per-sample loss (Lin et al., "Focal Loss for
    Dense Object Detection", https://arxiv.org/abs/1708.02002) is

        FL = -a_t * (1 - p_t)**gamma * log(p_t)

    with p_t = sigmoid(y_pred) for positives and 1 - sigmoid(y_pred) for
    negatives, and a_t = alpha for positives and 1 - alpha for negatives.

    Parameters
    ----------
    y_true : array_like
        Labels, assumed to be 0/1.
    y_pred : array_like
        Raw (pre-sigmoid) scores.
    alpha : float
        Weight of the positive class; negatives are weighted ``1 - alpha``.
    gamma : float
        Focusing parameter (assumed >= 0). ``gamma=0`` reduces FL to an
        alpha-weighted log loss.

    Returns
    -------
    grad, hess : ndarray
        First and second derivatives of FL with respect to ``y_pred``.
    """
    import numpy as np

    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    g = gamma
    # Alpha for positives, (1 - alpha) for negatives
    a = y_true * alpha + (1 - y_true) * (1 - alpha)
    # +1 for positives, -1 for negatives: p_t = sigmoid(s * y_pred)
    s = 2 * y_true - 1
    z = s * y_pred
    # log(p_t) and q = 1 - p_t computed in log space, so large |y_pred| neither
    # overflows np.exp nor rounds p_t to exactly 0 or 1 (which made the old
    # (1-p)**(g-2) / log(0) terms produce inf/nan).
    logp = -np.logaddexp(0.0, -z)
    p = np.exp(logp)
    q = np.exp(-np.logaddexp(0.0, z))
    # Chain rule: dp_t/dx = s*p*q and d2p_t/dx2 = p*q*(1 - 2p) (no sign factor).
    # The p*q factors are multiplied into f'(p) and f''(p) analytically, so no
    # negative powers of p or q remain.
    # f'(p) * p*q, with f'(p) = a q^(g-1) (g p log p - q) / p
    dfdp_pq = a * q**g * (g * p * logp - q)
    # f''(p) * (p*q)^2, with
    # f''(p) = -a q^(g-2) (g (g-1) p^2 log p - q ((2g-1) p + 1)) / p^2
    d2fdp2_pq2 = -a * q**g * (g * (g - 1) * p * p * logp - q * ((2 * g - 1) * p + 1))
    grad = s * dfdp_pq
    # (q - p) == 1 - 2*p_t; the Hessian is the same for both classes, so no s here
    hess = d2fdp2_pq2 + dfdp_pq * (q - p)
    return grad, hess
def focal_loss_lgb_eval(y_true, y_pred, alpha=0.25, gamma=1.0):
    """Return the mean binary focal loss as an evaluation-metric tuple.

    Signature ``(y_true, y_pred) -> (name, value, is_higher_better)`` as used
    for a custom LightGBM eval metric. The per-sample loss
    (https://arxiv.org/abs/1708.02002) is

        FL = -a_t * (1 - p_t)**gamma * log(p_t)

    with p_t = sigmoid(y_pred) for positives and 1 - sigmoid(y_pred) for
    negatives, and a_t = alpha for positives and 1 - alpha for negatives.

    Parameters
    ----------
    y_true : array_like
        Labels, assumed to be 0/1.
    y_pred : array_like
        Raw (pre-sigmoid) scores.
    alpha : float
        Weight of the positive class; negatives are weighted ``1 - alpha``.
    gamma : float
        Focusing parameter (assumed >= 0).

    Returns
    -------
    tuple
        ``('focal_loss', mean_loss, False)``. ``False`` because a lower loss
        is better.
    """
    import numpy as np

    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Alpha for positives, (1 - alpha) for negatives
    a = y_true * alpha + (1 - y_true) * (1 - alpha)
    # p_t = sigmoid(z) with the sign flipped for negatives
    z = (2 * y_true - 1) * y_pred
    # Stable log(p_t) and 1 - p_t. A direct sigmoid rounds to exactly 0/1 for
    # |y_pred| >~ 37, turning the loss into inf.
    logp = -np.logaddexp(0.0, -z)
    q = np.exp(-np.logaddexp(0.0, z))
    f = -a * logp * q**gamma
    return 'focal_loss', np.mean(f), False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment