@eamartin
Last active October 8, 2020 08:01
Straight-through estimator question
I recently read https://arxiv.org/abs/1308.3432, and want to make sure I'm understanding the
straight-through gradient estimator correctly. In general, I'm interested in conditional computation
and in propagating gradients back through non-smooth functions (or discrete distributions).
My understanding:
Let HT(x) = int(x >= 0) be the hard threshold function. For forward propagation, use the hard threshold
function. For backward propagation, replace every instance of HT(x) with some G(x) that has non-zero
gradient on some set of measure > 0 and that approximates HT over the domain of x's. For instance, G
can be the identity function if x is in [0, 1], or otherwise the sigmoid function.
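To make sure I have the mechanics right, here's a minimal sketch of this in PyTorch (my choice of
framework; the name st_hard_threshold is mine). The detach() trick gives HT(x) in the forward pass
and the identity gradient in the backward pass:

    import torch

    def st_hard_threshold(x):
        # Forward: z = HT(x) = 1 if x >= 0 else 0.
        # Backward: detach() removes HT from the autograd graph, so the
        # gradient flows through as if z were the identity (dz/dx = 1).
        ht = (x >= 0).float()
        return x + (ht - x).detach()

    x = torch.tensor([-0.5, 0.3], requires_grad=True)
    z = st_hard_threshold(x)
    z.sum().backward()
    print(z)       # tensor([0., 1.])  hard values in the forward pass
    print(x.grad)  # tensor([1., 1.])  identity gradient in the backward pass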
Applying this scheme to a simple decision-tree-style model:
Input: vector x
Output: scalar y
z = HT(s(x, W_s))
y = (1-z) * f_0(x, W_0) + z * f_1(x, W_1)
dy/dW_0 = (1-z) df_0/dW_0 (aka 0 if z=1)
dy/dW_1 = z df_1/dW_1 (aka 0 if z=0)
dy/dW_s = [-f_0(x, W_0) + f_1(x, W_1)] dz/dW_s ~= [-f_0(x, W_0) + f_1(x, W_1)] ds/dW_s
where the last step involves approximating HT(x) with the identity function.
Is this correct? If so, it implies:
(1) The gradient dy/dW_s depends on the values of both f_0 and f_1. This means that all
conditional paths must be computed during training (for the backward pass), but true conditional
execution can happen during inference.
(2) There's no update to W_0 if z=1, and vice versa for W_1.
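As a sanity check on (1) and (2), here's a quick numerical sketch (again PyTorch, with made-up
linear forms for s, f_0, f_1): W_s receives a gradient mixing f_0 and f_1 even though only one
branch is taken in the forward pass, and the unselected branch's weights get a zero gradient.

    import torch

    x = torch.tensor([1.0, 2.0])
    W_s = torch.randn(2, requires_grad=True)
    W_0 = torch.randn(2, requires_grad=True)
    W_1 = torch.randn(2, requires_grad=True)

    s = W_s @ x                                # gating logit s(x, W_s)
    z = s + ((s >= 0).float() - s).detach()    # straight-through HT(s)
    f0 = W_0 @ x                               # f_0(x, W_0)
    f1 = W_1 @ x                               # f_1(x, W_1)
    y = (1 - z) * f0 + z * f1
    y.backward()

    print(W_s.grad)            # equals (f1 - f0) * x, so both branches matter
    print(W_0.grad, W_1.grad)  # exactly one of these is the zero vector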
Some other questions:
* How does training with the hard threshold and straight-through estimator compare to training with
a soft threshold (sigmoid function) and then just sampling from the Bernoulli distribution
parameterized by the sigmoid output during inference? (See the sketch after this list.)
* What other research has been done in this area? What is the state of the art? Do all training
algorithms require evaluating all conditional branches rather than just one during training?
Where does reinforcement learning fit into the picture?
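For reference, the soft alternative in the first bullet would look roughly like this (a sketch of
what I mean, not a claim about what trains better):

    import torch

    def gate(s, training=True):
        p = torch.sigmoid(s)
        if training:
            return p                    # soft, fully differentiable gate
        with torch.no_grad():
            # At inference, sample a hard 0/1 decision from the Bernoulli
            # distribution parameterized by the sigmoid output.
            return torch.bernoulli(p)

    s = torch.tensor(0.4, requires_grad=True)
    z_train = gate(s, training=True)   # smooth value in (0, 1)
    z_infer = gate(s, training=False)  # hard 0. or 1. sample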
@AwesomeLemon

Hi. Have either of you (@eamartin, @nitishgupta) found answers to these questions? If you have, it'd be really nice of you to share them.
