Skip to content

Instantly share code, notes, and snippets.

@ariseff
Created September 24, 2020 20:09
Show Gist options
  • Save ariseff/cbcfd58d05abe7ec388bb4bbc9914ad9 to your computer and use it in GitHub Desktop.
Save ariseff/cbcfd58d05abe7ec388bb4bbc9914ad9 to your computer and use it in GitHub Desktop.
Visualizing KL divergence asymmetry with manim
"""
Source code for this visualization of KL divergence asymmetry:
https://twitter.com/ari_seff/status/1303741288911638530
The KL computation is based on
https://tuananhle.co.uk/notes/reverse-forward-kl.html
The script is written against the older 3b1b manim API
(`from manimlib.imports import *`, `GraphScene`, `TexMobject`).
Install manim (https://github.com/3b1b/manim) and run with:
$ manim asymkl.py AsymKL -pl
"""
from copy import deepcopy
import numpy as np
import scipy as sp
from scipy.stats import norm as normal
import manimlib
from manimlib.imports import *
# Camera settings shared by the scene below: render on a black background.
camera_config = dict(background_color=BLACK)
class Mixture:
    """One-dimensional Gaussian mixture model.

    Args:
        mixture_probs: Component weights; must sum to 1 (up to floating-point
            tolerance).
        means: Per-component means, one per weight.
        stds: Per-component standard deviations, one per weight.

    Raises:
        ValueError: If ``mixture_probs`` does not sum to 1.
    """

    def __init__(self, mixture_probs, means, stds):
        # Compare with a tolerance: weights such as [0.1] * 10 sum to
        # 0.9999999999999999 in floating point yet form a valid mixture.
        if not np.isclose(np.sum(mixture_probs), 1.0):
            raise ValueError('mixture_probs must sum to 1')
        self.num_mixtures = len(mixture_probs)
        self.mixture_probs = mixture_probs
        self.means = means
        self.stds = stds

    def logpdf(self, xs):
        """Return the log-density of the mixture evaluated at ``xs``.

        A scalar (or 0-d array) input yields a scalar. A 1-d input of any
        length, including length 1, yields a 1-d array of the same length.
        """
        scalar_input = np.ndim(xs) == 0
        xs = np.atleast_1d(np.asarray(xs, dtype=float))
        # Shape (len(xs), num_mixtures): log N(x | mean_k, std_k) for every
        # point/component pair, computed in a single broadcast call.
        component_logpdfs = normal.logpdf(
            xs[:, np.newaxis],
            loc=np.asarray(self.means, dtype=float),
            scale=np.asarray(self.stds, dtype=float),
        )
        # log sum_k w_k N(x | mean_k, std_k), evaluated stably in log space.
        res = sp.special.logsumexp(
            component_logpdfs + np.log(self.mixture_probs), axis=1)
        return res[0] if scalar_input else res

    def pdf(self, xs):
        """Return the mixture density at ``xs`` (same shape rules as logpdf)."""
        return np.exp(self.logpdf(xs))
def approx_kl(gmm_1, gmm_2, num_trapz_points=1000, num_stds=3):
    """Approximate KL(gmm_1 || gmm_2) by trapezoidal integration on a grid.

    Integrates ``p(x) * (log p(x) - log q(x))``, with ``p = gmm_1`` and
    ``q = gmm_2``, over an interval. The interval extends ``num_stds`` times
    the largest component standard deviation beyond the smallest and largest
    component means of either mixture. Mass outside the interval is ignored,
    so the truncation error shrinks as ``num_stds`` grows.

    Args:
        gmm_1: Distribution playing the role of p. It must expose ``means``,
            ``stds`` and a vectorized ``logpdf``.
        gmm_2: Distribution playing the role of q, with the same interface.
        num_trapz_points: Number of evenly spaced grid points.
        num_stds: Padding on each side of the grid, in units of the largest
            standard deviation.

    Returns:
        The approximate KL divergence (natural-log units).
    """
    all_means = np.append(gmm_1.means, gmm_2.means)
    max_std = np.max(np.append(gmm_1.stds, gmm_2.stds))
    xs = np.linspace(np.min(all_means) - num_stds * max_std,
                     np.max(all_means) + num_stds * max_std,
                     num_trapz_points)
    # Evaluate log p once and reuse it for both p(x) and the log-ratio.
    log_p = gmm_1.logpdf(xs)
    ys = np.exp(log_p) * (log_p - gmm_2.logpdf(xs))
    # np.trapz was renamed np.trapezoid in NumPy 2.0 (the old name is
    # deprecated); fall back to it on older NumPy versions.
    trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz
    return trapezoid(ys, xs)
class AsymKL(GraphScene):
    """Animate the asymmetry between D_KL(p || q) and D_KL(q || p).

    p is a two-component Gaussian mixture (blue) and q is a single Gaussian
    (green). Each animation step moves either p or q to a new target. Two
    orange bars, labelled (p||q) and (q||p), track the forward and reverse
    KL divergences computed by ``approx_kl``.
    """
    CONFIG = {
        'camera_config': camera_config,
        'x_axis_label': None,
        'y_axis_label': None,
        # Vertical range of the axes on which the densities are drawn.
        'y_max': 1,
        'y_min': -0.1,
        'graph_origin': 2.5 * DOWN + 5.5 * LEFT
    }

    def construct(self):
        """Set up curves, KL bars and labels, then play each step of distr_seq."""
        self.setup_axes(animate=False)
        # Initialize p
        # p: equal-weight mixture of two Gaussians that share one std.
        mu1_start = 4
        mu2_start = 6
        sigma = 0.75
        p_obj = Mixture(
            [0.5, 0.5], [mu1_start, mu2_start], [sigma, sigma])
        # Snapshot of p at the start of the current animation step;
        # update_pq interpolates away from it. The loop at the bottom
        # refreshes it after every play().
        last_p_obj = deepcopy(p_obj)
        p_color = BLUE
        p = self.get_graph(p_obj.pdf, p_color)
        # Initialize q
        # q: a single Gaussian, expressed as a one-component Mixture.
        q_mu = 5
        q_sigma = 1
        q_obj = Mixture([1.0], [q_mu], [q_sigma])
        last_q_obj = deepcopy(q_obj)
        q_color = GREEN
        q = self.get_graph(q_obj.pdf, q_color)

        class TempRect(Rectangle):
            """Temperature-representing rectangle."""
            # Filled orange bar of fixed width 0.7; its height encodes a KL
            # value.
            def __init__(self, height=1):
                Rectangle.__init__(self,
                    height=height, width=0.7, color=ORANGE,
                    fill_color=ORANGE, fill_opacity=1)

        # Forward-KL bar, D_KL(p || q). Height 1 is only a placeholder:
        # update_pq rebuilds both bars from the computed KL each time it runs.
        rect = TempRect()
        rect.move_to(1.7*DOWN + 4.7*RIGHT)
        # Reverse-KL bar, D_KL(q || p), placed to the right of the first bar.
        rkl_rect = TempRect()
        rkl_rect.next_to(rect, 2.8*RIGHT)
        # Labels "(p||q)" and "(q||p)" with p/q coloured like their curves,
        # plus a shared "D_KL" prefix.
        kl_label = TexMobject(R'(', 'p', R'||', 'q', ')')
        rkl_label = TexMobject(R'(', 'q', R'||', 'p', ')')
        dkl_label = TexMobject(R'D_{\mathrm{KL}}')
        label_size = 1.1
        for label in [kl_label, rkl_label, dkl_label]:
            label.set_color_by_tex(R'p', p_color)
            label.set_color_by_tex(R'q', q_color)
            label.scale(label_size)
        # The scalar multiples on the next_to directions below presumably
        # tune label spacing — confirm against manim's next_to.
        label_sep = 0.5*BOTTOM
        kl_label.next_to(rect, label_sep)
        rkl_label.next_to(rkl_rect, label_sep)
        dkl_label.next_to(kl_label, 1.1*LEFT)
        self.add(kl_label, rkl_label, dkl_label)
        # Order matters: update_pq reads group[0] (p curve) and group[1]
        # (q curve), and rebuilds the group in the same order.
        my_group = VGroup(*[p, q, rect, rkl_rect])

        def update_pq(group, alpha, distr):
            """Callback function to update p(x) and/or q(x).

            Used with UpdateFromAlphaFunc. It rebuilds the p and/or q curves
            at fraction ``alpha`` of the way from the step's starting state
            (last_p_obj / last_q_obj) to the target in ``distr``. It then
            recomputes both KL divergences and swaps everything into
            ``group``.

            Args:
                group: VGroup of [p curve, q curve, KL(p||q) bar,
                    KL(q||p) bar].
                alpha: Interpolation fraction; alpha = 1 reaches the target.
                distr: {'p': [mu1, mu2, sigma] or None,
                        'q': [mu, sigma] or None}. None leaves that
                    distribution unchanged.
            """
            # Rebound here so the outer loop can snapshot the final state of
            # each step into last_p_obj / last_q_obj.
            nonlocal p_obj
            nonlocal q_obj
            if distr['p'] is not None:
                # Update p
                mu1_diff = distr['p'][0] - last_p_obj.means[0]
                mu1_delta = interpolate(0, mu1_diff, alpha)
                mu2_diff = distr['p'][1] - last_p_obj.means[1]
                mu2_delta = interpolate(0, mu2_diff, alpha)
                sigma_diff = distr['p'][2] - last_p_obj.stds[0]
                sigma_delta = interpolate(0, sigma_diff, alpha)
                # Both components keep a shared std (derived from stds[0]).
                new_p_obj = Mixture([0.5, 0.5],
                    [last_p_obj.means[0]+mu1_delta, last_p_obj.means[1]+mu2_delta],
                    [last_p_obj.stds[0]+sigma_delta, last_p_obj.stds[0]+sigma_delta])
                p_obj = new_p_obj
                new_p = self.get_graph(new_p_obj.pdf, p_color)
            else:
                new_p = group[0]
            if distr['q'] is not None:
                # Update q
                mu_diff = distr['q'][0] - last_q_obj.means[0]
                mu_delta = interpolate(0, mu_diff, alpha)
                sigma_diff = distr['q'][1] - last_q_obj.stds[0]
                sigma_delta = interpolate(0, sigma_diff, alpha)
                new_q_obj = Mixture([1.0],
                    [last_q_obj.means[0]+mu_delta], [last_q_obj.stds[0]+sigma_delta])
                q_obj = new_q_obj
                new_q = self.get_graph(new_q_obj.pdf, q_color)
            else:
                new_q = group[1]
            # New KL divergence: forward KL(p || q) and reverse KL(q || p).
            kl = approx_kl(p_obj, q_obj)
            rkl = approx_kl(q_obj, p_obj)
            # Rectangles
            # Bar height = temp_scale * KL (natural-log units).
            temp_scale = 0.25
            # Bottom-align the new bars with rect / rkl_rect (the group's own
            # submobjects) so the bars grow upward from a fixed baseline.
            new_rect = TempRect(height=temp_scale*kl)
            new_rect.move_to(rect, aligned_edge=BOTTOM)
            new_rkl_rect = TempRect(height=temp_scale*rkl)
            new_rkl_rect.move_to(rkl_rect, aligned_edge=BOTTOM)
            group.become(VGroup(
                *[new_p, new_q, new_rect, new_rkl_rect]))

        # Target states for successive steps. Each entry moves exactly one of
        # p ([mu1, mu2, sigma]) or q ([mu, sigma]); the other is None.
        distr_seq = [{'p': [2, 8, 0.75], 'q': None},
                     {'p': None, 'q': [2, 1]},
                     {'p': None, 'q': [2, 0.7]},
                     {'p': None, 'q': [5, 0.7]},
                     {'p': None, 'q': [5, 2]},
                     {'p': [2, 8, 0.35], 'q': None},
                     {'p': None, 'q': [5, 1]},
                     {'p': [4, 6, 0.75], 'q': None}]
        for distr in distr_seq:
            # The lambda looks up `distr` when it runs. That is safe here
            # because play() finishes this animation before the loop rebinds
            # distr.
            self.play(
                UpdateFromAlphaFunc(my_group,
                    lambda group, alpha : update_pq(group, alpha, distr)),
                run_time=2.75)
            # The next step interpolates from where this one ended.
            last_p_obj = deepcopy(p_obj)
            last_q_obj = deepcopy(q_obj)
            # NOTE(review): indentation was lost in this copy; the wait is
            # placed inside the loop (a pause after each step). Confirm
            # against the original source.
            self.wait(0.75)
@EricCousineau-TRI
Copy link

Thanks! For kicks and giggles, ported to use manim community edition:
https://gist.github.com/EricCousineau-TRI/68cb261bdd13184000ea3da9f505c69d

asymkl.mp4

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment