
@gsoykan
Created March 24, 2024 18:15
Notes on KL Div & KL Div Loss

KL divergence loss, or Kullback-Leibler divergence loss, measures how one probability distribution diverges from a second, expected probability distribution. It's often used in scenarios where you want to compare two probability distributions, typically in the context of machine learning models like neural networks.

In PyTorch, the KL divergence loss is implemented by the torch.nn.KLDivLoss class. It compares a predicted probability distribution $Q(x)$ (supplied as log-probabilities) against a target probability distribution $P(x)$. The KL divergence is defined as:

$$D_{KL}(P \,\|\, Q) = \sum_x P(x) \log\left(\frac{P(x)}{Q(x)}\right)$$

where:

- $P(x)$ is the true (target) probability distribution.
- $Q(x)$ is the predicted probability distribution.
- $D_{KL}(P \,\|\, Q)$ is the KL divergence, i.e. how much information is lost when $Q(x)$ is used to approximate $P(x)$.
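The sum above can be computed directly for discrete distributions. A minimal sketch in plain Python (the probability values here are made-up examples):

```python
import math

def kl_divergence(p, q):
    """D_KL(P || Q) for two discrete distributions given as lists of probabilities.

    Terms with P(x) == 0 contribute nothing (by the convention 0 * log 0 = 0).
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.4, 0.4, 0.2]  # true distribution P(x)
q = [0.3, 0.5, 0.2]  # predicted distribution Q(x)

print(kl_divergence(p, q))  # small positive value, since q is close to p
print(kl_divergence(p, p))  # 0.0 -- a distribution has zero divergence from itself
```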

In practice, for neural networks, $P(x)$ often represents the ground truth distribution, and $Q(x)$ represents the distribution predicted by the network. KL divergence is particularly popular in tasks like model regularization, variational inference, and working with generative models, such as Variational Autoencoders (VAEs).
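In the VAE case, the KL term between a diagonal-Gaussian posterior and a standard-normal prior has a well-known closed form, so it is usually computed directly rather than via KLDivLoss. A sketch (the tensor shapes are illustrative assumptions):

```python
import torch

def gaussian_kl(mu, logvar):
    # Closed-form KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dimensions:
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) -- the usual VAE regularizer.
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)

mu = torch.zeros(4, 8)      # hypothetical batch of 4 latent means, 8 dims each
logvar = torch.zeros(4, 8)  # log-variances; all zeros => posterior equals the prior

print(gaussian_kl(mu, logvar))  # zero KL when the posterior matches the prior
```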

One important aspect of the KL divergence is that it is not symmetric. This means that $D_{KL}(P \,\|\, Q)$ is not necessarily equal to $D_{KL}(Q \,\|\, P)$, indicating that the "distance" from $P$ to $Q$ can be different from the "distance" from $Q$ to $P$.
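The asymmetry is easy to verify numerically (again with made-up probability values):

```python
import math

def kl(p, q):
    # D_KL(P || Q) over a discrete support, skipping zero-probability terms.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.9, 0.1]
q = [0.5, 0.5]

print(kl(p, q))  # one direction...
print(kl(q, p))  # ...differs from the other, so KL is not a true metric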

In PyTorch, KLDivLoss expects the input (predicted distribution) as log-probabilities and the target as plain probabilities. In practice this means applying log_softmax to the network's output logits and softmax (or using an existing probability distribution) for the target; passing raw logits or plain probabilities as the input produces incorrect values without raising an error.
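A minimal usage sketch (the logits here are random placeholders for real network outputs):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# reduction="batchmean" divides by the batch size, matching the mathematical
# definition of KL divergence; the default "mean" averages over all elements.
kl_loss = nn.KLDivLoss(reduction="batchmean")

logits = torch.randn(4, 5)                    # hypothetical network outputs
input = F.log_softmax(logits, dim=1)          # Q(x) as log-probabilities
target = F.softmax(torch.randn(4, 5), dim=1)  # P(x) as probabilities

loss = kl_loss(input, target)
print(loss)  # a non-negative scalar
```

KLDivLoss also accepts log_target=True when the target is itself given in log-space, which can be more numerically stable.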
