@EdisonLeeeee
Last active August 22, 2022 10:53
PyTorch equivalents of tf.nn.l2_normalize
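For reference, here is what each call computes along the normalized axis, as given in the TensorFlow and PyTorch documentation. Both default to ε = 1e-12, but TensorFlow applies it to the squared sum while PyTorch applies it to the norm itself:

```latex
\text{tf.nn.l2\_normalize:}\quad y = \frac{x}{\sqrt{\max\left(\sum_j x_j^2,\ \epsilon\right)}}
\qquad
\text{F.normalize}\ (p=2):\quad y = \frac{x}{\max\left(\lVert x \rVert_2,\ \epsilon\right)}
```

The two therefore agree whenever the row norm is at least sqrt(1e-12) = 1e-6. For an all-zero row both return zeros. For a nonzero row with a smaller norm, only the PyTorch result is still a unit vector.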
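import torch
import tensorflow as tf

x = torch.randn(5, 6)
########## PyTorch Version 1 ################
# Divide each row by its L2 norm; an all-zero row gives 0/0 = NaN
norm_th1 = x / torch.norm(x, p=2, dim=1, keepdim=True)
norm_th1[torch.isnan(norm_th1)] = 0  # replace NaN from all-zero rows with 0
########## PyTorch Version 2 ################
# F.normalize divides by max(norm, eps) with eps=1e-12, so no NaN handling is needed
norm_th2 = torch.nn.functional.normalize(x, p=2, dim=1)
########### Equivalent to ############
norm_tf = tf.nn.l2_normalize(x.numpy(), axis=1)
print(norm_th1)
print(norm_th2)
print(norm_tf)
# Check that the PyTorch and TensorFlow results match numerically
print(torch.allclose(norm_th2, torch.from_numpy(norm_tf.numpy()), atol=1e-6))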
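A quick way to see where the two definitions coincide and where they diverge, without installing either framework, is to transcribe both documented formulas in NumPy. This is a sketch under that assumption. `tf_style_l2_normalize` and `torch_style_normalize` are hypothetical helper names, not library functions. The code reimplements the formulas rather than calling TensorFlow or PyTorch.

```python
import numpy as np

EPS = 1e-12  # default epsilon of both tf.nn.l2_normalize and F.normalize

def tf_style_l2_normalize(x, axis=1, epsilon=EPS):
    # Hypothetical helper: TensorFlow's documented formula,
    # x / sqrt(max(sum(x**2), epsilon)), i.e. epsilon bounds the SQUARED sum
    square_sum = np.sum(x * x, axis=axis, keepdims=True)
    return x / np.sqrt(np.maximum(square_sum, epsilon))

def torch_style_normalize(x, axis=1, eps=EPS):
    # Hypothetical helper: PyTorch's documented formula for p=2,
    # x / max(||x||_2, eps), i.e. eps bounds the norm itself
    norm = np.sqrt(np.sum(x * x, axis=axis, keepdims=True))
    return x / np.maximum(norm, eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 6))
x[2] = 0.0   # an all-zero row
x[4] = 1e-9  # a nonzero row whose norm (~2.4e-9) is below sqrt(1e-12)

tf_out = tf_style_l2_normalize(x)
th_out = torch_style_normalize(x)

ordinary = [0, 1, 3]
print("ordinary rows agree:", np.allclose(tf_out[ordinary], th_out[ordinary]))
print("zero row (TF-style):", tf_out[2])
print("zero row (PyTorch-style):", th_out[2])
print("tiny row norm (TF-style):", np.linalg.norm(tf_out[4]))
print("tiny row norm (PyTorch-style):", np.linalg.norm(th_out[4]))
```

Rows 0, 1 and 3 agree, and the all-zero row is all zeros in both versions. The row of 1e-9 entries comes out as a unit vector only in the PyTorch-style version, because TensorFlow's effective floor of sqrt(1e-12) = 1e-6 on the norm takes over first.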

thakurudit commented Mar 16, 2022
