Skip to content

Instantly share code, notes, and snippets.

@EricCousineau-TRI
Forked from ariseff/asymkl.py
Created November 24, 2021 23:50
Show Gist options
  • Save EricCousineau-TRI/68cb261bdd13184000ea3da9f505c69d to your computer and use it in GitHub Desktop.
Save EricCousineau-TRI/68cb261bdd13184000ea3da9f505c69d to your computer and use it in GitHub Desktop.
Visualizing KL divergence asymmetry with manim (community edition, 0.12.0)
%%manim -qm -v WARNING AsymKL
# https://gist.github.com/ariseff/cbcfd58d05abe7ec388bb4bbc9914ad9
"""
Source code for this visualization of KL divergence asymmetry:
https://twitter.com/ari_seff/status/1303741288911638530
KL computation is based on
https://tuananhle.co.uk/notes/reverse-forward-kl.html
"""
from copy import deepcopy
import numpy as np
import scipy as sp
from scipy.stats import norm as normal
from manim import *
class Mixture:
    """One-dimensional Gaussian mixture model.

    Parameters
    ----------
    mixture_probs : sequence of float
        Component weights; must sum to 1 (up to floating-point round-off).
    means : sequence of float
        Per-component means.
    stds : sequence of float
        Per-component standard deviations.

    Raises
    ------
    ValueError
        If the weights do not sum to 1, or if the three sequences have
        different lengths.
    """

    def __init__(self, mixture_probs, means, stds):
        # Compare with a tolerance: e.g. sum([0.1] * 10) == 0.9999999999999999,
        # which an exact ``!= 1`` test would wrongly reject.
        if not np.isclose(np.sum(mixture_probs), 1.0, rtol=0.0, atol=1e-9):
            raise ValueError('mixture_probs must sum to 1')
        if not len(mixture_probs) == len(means) == len(stds):
            raise ValueError(
                'mixture_probs, means and stds must have the same length')
        self.num_mixtures = len(mixture_probs)
        # Stored exactly as given; callers index these directly
        # (e.g. ``means[0]``, ``stds[0]``).
        self.mixture_probs = mixture_probs
        self.means = means
        self.stds = stds

    def logpdf(self, xs):
        """Return the log density of the mixture evaluated at ``xs``.

        A scalar input (including a NumPy scalar or 0-d array) yields a
        scalar. A 1-D input yields a 1-D array of the same length, even when
        that length is 1.
        """
        scalar_input = np.ndim(xs) == 0
        xs = np.atleast_1d(np.asarray(xs, dtype=float))
        # Broadcast to shape (len(xs), num_mixtures): column k holds
        # log N(x | means[k], stds[k]).
        component_logpdfs = normal.logpdf(
            xs[:, np.newaxis],
            loc=np.asarray(self.means, dtype=float),
            scale=np.asarray(self.stds, dtype=float),
        )
        # log sum_k w_k N(x | mu_k, sigma_k), computed stably in log space.
        res = sp.special.logsumexp(
            component_logpdfs + np.log(self.mixture_probs), axis=1)
        return res[0] if scalar_input else res

    def pdf(self, xs):
        """Return the mixture density at ``xs`` (same shape rules as logpdf)."""
        return np.exp(self.logpdf(xs))
def approx_kl(gmm_1, gmm_2, num_trapz_points=1000, num_stds=3):
    """Approximate KL(gmm_1 || gmm_2) by trapezoidal integration.

    The integrand p(x) * (log p(x) - log q(x)) is evaluated on an evenly
    spaced grid. The grid spans every component mean of both mixtures and is
    padded on each side by ``num_stds`` times the largest component standard
    deviation.

    Parameters
    ----------
    gmm_1, gmm_2 : Mixture
        The distributions p and q. Each must expose ``means``, ``stds`` and a
        ``logpdf`` that accepts a 1-D array.
    num_trapz_points : int
        Number of grid points.
    num_stds : float
        Grid padding in units of the largest standard deviation. Mass outside
        the grid is ignored, so larger values reduce truncation error.

    Returns
    -------
    float
        The approximate divergence. Because of truncation and discretization
        error, it can deviate slightly from the true (non-negative) value.
    """
    all_means = np.append(gmm_1.means, gmm_2.means)
    pad = num_stds * np.max(np.append(gmm_1.stds, gmm_2.stds))
    xs = np.linspace(np.min(all_means) - pad, np.max(all_means) + pad,
                     num_trapz_points)
    # Evaluate log p once and reuse it for p = exp(log p).
    log_p = gmm_1.logpdf(xs)
    ys = np.exp(log_p) * (log_p - gmm_2.logpdf(xs))
    # np.trapz was renamed np.trapezoid in NumPy 2.0 (old name deprecated).
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    return trapezoid(ys, xs)
class AsymKL(Scene):
    """Animate forward vs. reverse KL between a bimodal p and a unimodal q.

    Two orange bars show D_KL(p || q) and D_KL(q || p). The parameters of p
    and q are interpolated through the keyframes in ``distr_seq``.
    """

    def construct(self):
        ax = Axes(
            x_range=[-1, 10],
            y_range=[-0.1, 1.0],
            tips=False,
        ).scale(0.9).shift(1.25 * LEFT + 0.25 * DOWN)
        self.add(ax)

        # Initialize p: equal-weight two-component mixture sharing one sigma.
        mu1_start = 4
        mu2_start = 6
        sigma = 0.75
        p_obj = Mixture(
            [0.5, 0.5], [mu1_start, mu2_start], [sigma, sigma])
        last_p_obj = deepcopy(p_obj)
        p_color = BLUE
        p = ax.plot(p_obj.pdf, color=p_color)

        # Initialize q: a single Gaussian.
        q_mu = 5
        q_sigma = 1
        q_obj = Mixture([1.0], [q_mu], [q_sigma])
        last_q_obj = deepcopy(q_obj)
        q_color = GREEN
        q = ax.plot(q_obj.pdf, color=q_color)

        class TempRect(Rectangle):
            """Filled orange bar whose height visualizes a KL value."""

            def __init__(self, height=1):
                super().__init__(
                    height=height, width=0.7, color=ORANGE,
                    fill_color=ORANGE, fill_opacity=1)

        rect = TempRect()
        rect.move_to(2.5*DOWN + 4.75*RIGHT)
        rkl_rect = TempRect()
        rkl_rect.next_to(rect, 2.8*RIGHT)

        kl_label = MathTex(R'(', 'p', R'||', 'q', ')')
        rkl_label = MathTex(R'(', 'q', R'||', 'p', ')')
        dkl_label = MathTex(R'D_{\mathrm{KL}}')
        label_size = 1.1
        for label in [kl_label, rkl_label, dkl_label]:
            label.set_color_by_tex(R'p', p_color)
            label.set_color_by_tex(R'q', q_color)
            label.scale(label_size)
        label_sep = 0.5*DOWN
        kl_label.next_to(rect, label_sep)
        rkl_label.next_to(rkl_rect, label_sep)
        dkl_label.next_to(kl_label, 1.1*LEFT)
        self.add(kl_label, rkl_label, dkl_label)

        my_group = VGroup(*[p, q, rect, rkl_rect])
        # Bar height per unit of KL divergence.
        temp_scale = 0.25

        def update_pq(group, alpha, distr):
            """Callback function to update p(x) and/or q(x).

            Moves each distribution named in ``distr`` a fraction ``alpha`` of
            the way from the previous keyframe (``last_p_obj`` /
            ``last_q_obj``) to its target. It then rebuilds the curves and
            bars in ``group``. ``distr['p']`` is ``[mu1, mu2, sigma]``,
            ``distr['q']`` is ``[mu, sigma]``, and ``None`` leaves that
            distribution unchanged.
            """
            nonlocal p_obj
            nonlocal q_obj
            if distr['p'] is not None:
                # Update p
                mu1_diff = distr['p'][0] - last_p_obj.means[0]
                mu1_delta = interpolate(0, mu1_diff, alpha)
                mu2_diff = distr['p'][1] - last_p_obj.means[1]
                mu2_delta = interpolate(0, mu2_diff, alpha)
                sigma_diff = distr['p'][2] - last_p_obj.stds[0]
                sigma_delta = interpolate(0, sigma_diff, alpha)
                new_p_obj = Mixture(
                    [0.5, 0.5],
                    [last_p_obj.means[0]+mu1_delta,
                     last_p_obj.means[1]+mu2_delta],
                    [last_p_obj.stds[0]+sigma_delta,
                     last_p_obj.stds[0]+sigma_delta])
                p_obj = new_p_obj
                new_p = ax.plot(new_p_obj.pdf, color=p_color)
            else:
                new_p = group[0]
            if distr['q'] is not None:
                # Update q
                mu_diff = distr['q'][0] - last_q_obj.means[0]
                mu_delta = interpolate(0, mu_diff, alpha)
                sigma_diff = distr['q'][1] - last_q_obj.stds[0]
                sigma_delta = interpolate(0, sigma_diff, alpha)
                new_q_obj = Mixture(
                    [1.0],
                    [last_q_obj.means[0]+mu_delta],
                    [last_q_obj.stds[0]+sigma_delta])
                q_obj = new_q_obj
                new_q = ax.plot(new_q_obj.pdf, color=q_color)
            else:
                new_q = group[1]
            # New KL divergence
            kl = approx_kl(p_obj, q_obj)
            rkl = approx_kl(q_obj, p_obj)
            # Rectangles. The numerical estimate can dip slightly below 0.
            # A negative height would flip the bar downward, and since
            # `become` rewrites `rect`/`rkl_rect` in place, the baseline they
            # provide would then drift. Clamp at 0 instead.
            new_rect = TempRect(height=temp_scale*max(kl, 0.0))
            new_rect.move_to(rect, aligned_edge=DOWN)
            new_rkl_rect = TempRect(height=temp_scale*max(rkl, 0.0))
            new_rkl_rect.move_to(rkl_rect, aligned_edge=DOWN)
            group.become(VGroup(
                *[new_p, new_q, new_rect, new_rkl_rect]))

        distr_seq = [{'p': [2, 8, 0.75], 'q': None},
                     {'p': None, 'q': [2, 1]},
                     {'p': None, 'q': [2, 0.7]},
                     {'p': None, 'q': [5, 0.7]},
                     {'p': None, 'q': [5, 2]},
                     {'p': [2, 8, 0.35], 'q': None},
                     {'p': None, 'q': [5, 1]},
                     {'p': [4, 6, 0.75], 'q': None}]
        for distr in distr_seq:
            self.play(
                UpdateFromAlphaFunc(
                    my_group,
                    # Bind this keyframe now (default arg) instead of looking
                    # up the loop variable whenever the callback runs.
                    lambda group, alpha, distr=distr: update_pq(
                        group, alpha, distr)),
                run_time=2.75)
            # The finished keyframe becomes the start of the next one;
            # update_pq reads these through its closure.
            last_p_obj = deepcopy(p_obj)
            last_q_obj = deepcopy(q_obj)
            self.wait(0.75)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment