Created December 18, 2020 21:28 (gist maciejskorski/efbef75d2c03ae8d734f9c34bdf54385)
efficient logistic regression
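The training step below minimizes the softmax (multinomial logistic regression) negative log-likelihood. Written out in our own notation for one flattened image $x \in \mathbb{R}^{784}$ with label $y \in \{0,\dots,9\}$ and weights $W \in \mathbb{R}^{784 \times 10}$ (no bias term, matching the code), `logp` and its gradient are:

```latex
\begin{aligned}
z &= W^\top x \in \mathbb{R}^{10} \\
\log p(y \mid x) &= z_y - \log \sum_{k=0}^{9} e^{z_k},
  \qquad \log \sum_{k} e^{z_k} = m + \log \sum_{k} e^{z_k - m},\quad m = \max_k z_k \\
\ell(W; x, y) &= -\log p(y \mid x),
  \qquad \nabla_W \ell = x \,\bigl(\operatorname{softmax}(z) - e_y\bigr)^\top
\end{aligned}
```

Here $e_y$ is the one-hot vector for class $y$. Selecting $z_y$ directly with `tf.gather(..., batch_dims=1)` avoids building a one-hot label matrix, and the max-shift identity is what keeps the logsumexp finite for large logits. For comparison, TensorFlow's `tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=all_logits)` returns the same per-example $\ell$.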
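```python
import tensorflow as tf

## train a logistic regression (images 28x28 and 10 classes)
# weights map a flattened 784-pixel image to 10 class logits (no bias term)
w = tf.Variable(tf.random.normal(shape=(28*28, 10), stddev=0.1), trainable=True)
optimizer = tf.optimizers.SGD(0.01)

@tf.function
def train_step(x, y):
    # x: float32 tensor (n_batch, 784); y: integer class labels (n_batch,)
    with tf.GradientTape() as tape:
        all_logits = tf.matmul(x, w)                       # (n_batch, n_class)
        y_logits = tf.gather(all_logits, y, batch_dims=1)  # (n_batch,) logit of the true class
        # log-softmax of the true class, computed stably via logsumexp
        logp = y_logits - tf.reduce_logsumexp(all_logits, axis=1)
        loss = -tf.reduce_mean(logp)                       # mean negative log-likelihood
    gradients = tape.gradient(loss, [w])
    optimizer.apply_gradients(zip(gradients, [w]))
    return loss
```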
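To check the computation outside TensorFlow, here is a minimal NumPy sketch of the same training step. It assumes float64 NumPy arrays `x` of shape (n_batch, 784) and integer labels `y`; the helper `log_softmax_true_class` and this `train_step(w, x, y, lr)` signature are hypothetical. The loss is averaged over the batch, and the gradient (softmax probabilities minus one-hot labels, multiplied by `x` and averaged over the batch) is written out by hand instead of using automatic differentiation as the gist does with `tf.GradientTape`.

```python
import numpy as np


def log_softmax_true_class(logits, y):
    """Log-probability of the true class for each row of `logits` (hypothetical helper)."""
    rows = np.arange(logits.shape[0])
    # select one logit per row, the role tf.gather(..., batch_dims=1) plays in the gist
    y_logits = logits[rows, y]
    # stable logsumexp over classes: shift by the row max before exponentiating
    m = logits.max(axis=1)
    lse = m + np.log(np.exp(logits - m[:, None]).sum(axis=1))
    return y_logits - lse


def train_step(w, x, y, lr=0.01):
    """One gradient step on the mean negative log-likelihood; returns (new_w, loss)."""
    logits = x @ w                                   # (n_batch, n_class)
    loss = -log_softmax_true_class(logits, y).mean()
    # softmax probabilities, again shifted by the row max for stability
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    # d loss / d logits = (softmax - one_hot) / n_batch
    probs[np.arange(len(y)), y] -= 1.0
    grad = x.T @ probs / len(y)                      # (n_features, n_class), same shape as w
    return w - lr * grad, loss


if __name__ == "__main__":
    # synthetic stand-in for flattened 28x28 images with 10 classes
    rng = np.random.default_rng(0)
    x = rng.normal(size=(256, 28 * 28))
    y = rng.integers(0, 10, size=256)
    w = rng.normal(scale=0.1, size=(28 * 28, 10))
    for step in range(50):
        w, loss = train_step(w, x, y)
        if step % 10 == 0:
            print(f"step {step}: loss {loss:.4f}")
```

The demo runs full-batch gradient descent on random data only to show the loss going down; the returned loss is the value before each update.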