Straight-through estimator question
I recently read https://arxiv.org/abs/1308.3432, and want to make sure I'm understanding the
straight-through gradient estimator correctly. In general, I'm interested in conditional computation
and propagating gradients back through non-smooth functions (or discrete distributions).

My understanding:

Let HT(x) = int(x >= 0) be the hard threshold function. For forward propagation, use the hard
threshold function. For backward propagation, replace every instance of HT(x) with some G(x) that
has non-zero gradient on a set of measure > 0 and that approximates HT over the domain of x. For
instance, G can be the identity function if x is in [0, 1], or otherwise the sigmoid function.
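
To make that concrete, here is a minimal sketch in PyTorch of how I read the estimator: the forward
pass uses the hard threshold, while the backward pass routes the gradient through the identity
surrogate. The function name hard_threshold_st is mine, just for illustration.

```python
import torch

def hard_threshold_st(s):
    # Forward value: HT(s) = int(s >= 0).
    hard = (s >= 0).float()
    # Backward: detach() blocks the gradient through the hard part, and adding s
    # back makes d(output)/ds = 1, i.e. the straight-through (identity) estimator.
    return (hard - s).detach() + s
```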
Applying this to a simple decision tree model:

Input: vector x
Output: scalar y

z = HT(s(x, W_s))
y = (1 - z) * f_0(x, W_0) + z * f_1(x, W_1)

dy/dW_0 = (1 - z) df_0/dW_0   (i.e. 0 if z = 1)
dy/dW_1 = z df_1/dW_1         (i.e. 0 if z = 0)
dy/dW_s = [f_1(x, W_1) - f_0(x, W_0)] dz/dW_s ~= [f_1(x, W_1) - f_0(x, W_0)] ds/dW_s

where the last step approximates the gradient of HT with the gradient of the identity function
(i.e. dz/ds ~= 1).
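
As a sanity check on those gradients, here is a hedged sketch of the gated model above, again in
PyTorch and reusing the straight-through threshold from the earlier snippet. The layer choices and
dimensions (plain linear layers for s, f_0, f_1) are arbitrary and only meant to make the gradient
flow concrete.

```python
import torch
import torch.nn as nn

def hard_threshold_st(s):
    hard = (s >= 0).float()
    return (hard - s).detach() + s   # hard threshold forward, identity backward

class GatedModel(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.s = nn.Linear(dim, 1)    # gating score s(x, W_s)
        self.f0 = nn.Linear(dim, 1)   # branch f_0(x, W_0)
        self.f1 = nn.Linear(dim, 1)   # branch f_1(x, W_1)

    def forward(self, x):
        z = hard_threshold_st(self.s(x))          # hard gate, straight-through backward
        return (1 - z) * self.f0(x) + z * self.f1(x)

model = GatedModel(8)
x = torch.randn(4, 8)
model(x).sum().backward()
# model.s.weight.grad depends on f_1(x) - f_0(x), so both branches were evaluated;
# rows where z = 1 contribute nothing to model.f0.weight.grad, and vice versa for f1.
```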
Is this correct? If so, it implies:

(1) The gradient dy/dW_s depends on the values of both f_0 and f_1. This means that all
conditional paths must be computed during training (for the backward pass), but true conditional
execution can happen during inference.
(2) There's no update to W_0 if z = 1, and vice versa for W_1.
Some other questions:

* How does training with a hard threshold and the straight-through estimator compare to training
  with a soft threshold (sigmoid function), and then just sampling from the Bernoulli distribution
  parameterized by the sigmoid output during inference?
* What other research has been done in this area? What is the state of the art? Do all training
  algorithms require evaluating all conditional branches rather than just one during training?
  Where does reinforcement learning fit into the picture?
Hi. Have either of you (@eamartin, @nitishgupta) found answers to those questions? If you have, it'd be really nice of you to share them.