Skip to content

Instantly share code, notes, and snippets.

@flacle
Last active May 28, 2024 10:27
Show Gist options
  • Save flacle/a93ff64b80e85d3ee715d201f5f7b7b6 to your computer and use it in GitHub Desktop.
Save flacle/a93ff64b80e85d3ee715d201f5f7b7b6 to your computer and use it in GitHub Desktop.
Manim Gradient Descent Intuition (in Papiamento)
# Manim Gradient Descent Video Intuition
# Author: Francis Laclé
# Video: https://www.youtube.com/watch?v=1cCS6uK_NH8
# Github: https://github.com/flacle
# Date: 29 Oct, 2020
from manim import *
import math
class Intro(Scene):
    """Opening title card.

    Writes "Gradient Descent" with a blue-to-green gradient, holds it for
    11 seconds, then fades it out.
    """

    def construct(self):
        title = PangoText('Gradient Descent', gradient=(BLUE, GREEN))
        title.scale(2)
        self.wait(1)
        # The title is added to the scene first and then animated with Write.
        self.add(title)
        self.play(Write(title))
        self.wait(11)
        self.play(FadeOut(title))
        self.wait(1)
class ThreeDSurface(ParametricSurface):
    """Saddle surface z = x**2 - y**2 sampled over the square [-2, 2] x [-2, 2].

    Keyword arguments given by the caller are forwarded to
    ``ParametricSurface`` and override the defaults defined in ``__init__``.
    """

    def __init__(self, **kwargs):
        # Defaults first, caller-supplied kwargs last so explicit arguments win
        # instead of being thrown away.
        config = {
            "u_min": -2,
            "u_max": 2,
            "v_min": -2,
            "v_max": 2,
            "checkerboard_colors": [BLUE_D],
            **kwargs,
        }
        ParametricSurface.__init__(self, self.func, **config)

    def func(self, x, y):
        """Map the surface parameters (x, y) to the 3D point (x, y, x**2 - y**2)."""
        return np.array([x, y, x**2 - y**2])
class ConYPakicoTraha(ThreeDScene):
    """Draw 3D axes and the saddle surface under a slowly rotating camera.

    The camera is swung once via ``move_camera``, the ambient rotation is
    stopped, and the surface and axes are faded out.
    """

    def construct(self):
        axes_3d = ThreeDAxes(animate=True)
        saddle = ThreeDSurface()
        self.set_camera_orientation(phi=75 * DEGREES, theta=30 * DEGREES, distance=30)
        self.begin_ambient_camera_rotation(rate=0.1)
        self.wait(1)
        for mob in (axes_3d, saddle):
            self.play(ShowCreation(mob))
        self.wait(4)
        # New camera angles (first two positional args of move_camera).
        self.move_camera(0.4 * np.pi, -0.45 * np.pi)
        self.wait(4)
        self.stop_ambient_camera_rotation()
        for mob in (saddle, axes_3d):
            self.play(FadeOut(mob))
class KostFunctie(Scene):
    """Introduce the cost function J(...) and morph it into min J(...)."""

    def construct(self):
        cost, min_cost = (
            Tex(source).scale(3)
            for source in (
                r'$J\left(\cdot\cdot\cdot\right)$',
                r'$\min{J\left(\cdot\cdot\cdot\right)}$',
            )
        )
        self.wait(1)
        self.play(Write(cost))
        self.wait(4)
        self.play(ReplacementTransform(cost, min_cost))
        self.wait(6)
        self.play(FadeOut(min_cost))
class OnderzoekAruba(GraphScene):
    """Scatter ten yearly population samples, draw a fitted line and its residuals.

    In graph units x = 0..9 is labelled '09..'18 and y = 0..6 is labelled
    100.000..106.000 (see ``setup_axes``).
    """

    CONFIG = {
        "y_axis_label": r"Poblacion di Aruba",
        "x_axis_label": "Aña",
        "y_max": 7,
        "y_min": 0,
        "y_tick_frequency": 1,
        "x_max": 9,
        "x_min": 0,
        "axes_color": BLUE,
    }

    @staticmethod
    def _fit_line(x):
        """Hand-picked straight line through (0, 1) and (9, 6): y = 5/9 * x + 1."""
        return (5 / 9 * x) + 1

    def construct(self):
        # One sample per year, already expressed in graph units.
        data = [1, 1.243735763, 1.673120729, 2.258542141, 2.940774487,
                3.641230068, 4.287015945, 4.891799544, 5.454441913, 6]
        self.setup_axes()
        # The same function drives both the drawn line and the residual
        # endpoints below, so the two can never disagree.
        line = self.get_graph(
            self._fit_line,
            color=RED,
            x_min=0,
            x_max=9,
            label="$J(x)$",
        )

        dot_collection = VGroup()
        for year, value in enumerate(data):
            dot = Dot(color=YELLOW).move_to(self.coords_to_point(year, value))
            dot_collection.add(dot)
            self.play(FadeIn(dot), rate_func=rush_into)
        self.wait(1)
        self.play(ShowCreation(line), run_time=2)
        self.wait(1)

        error_collection = VGroup()
        for year, dot in enumerate(dot_collection):
            # Vertical residual from the fitted line to the observed sample.
            error = Line(
                self.coords_to_point(year, self._fit_line(year)),
                dot.get_center(),
                color=GREEN,
            )
            error_collection.add(error)
            self.play(ShowCreation(error), run_time=1)
        self.wait(2)
        self.play(
            FadeOut(error_collection),
            FadeOut(dot_collection),
            FadeOut(line),
            FadeOut(self.axes),
            FadeOut(self.x_axis_labels),
            FadeOut(self.y_axis_labels),
        )

    def setup_axes(self):
        """Build the axes with custom year / population tick labels and animate them in.

        Overrides ``GraphScene.setup_axes``: the base version is called without
        ``animate``, then the tick labels are stored on ``self.x_axis_labels``
        and ``self.y_axis_labels`` so ``construct`` can fade them out later.
        """
        GraphScene.setup_axes(self)
        self.x_axis.label_direction = UP
        self.y_axis.label_direction = UP
        values_x = [
            (0, "'09"),
            (1, "'10"),
            (2, "'11"),
            (3, "'12"),
            (4, "'13"),
            (5, "'14"),
            (6, "'15"),
            (7, "'16"),
            (8, "'17"),
            (9, "'18"),
        ]
        values_y = [
            (0, "100.000"),
            (1, "101.000"),
            (2, "102.000"),
            (3, "103.000"),
            (4, "104.000"),
            (5, "105.000"),
            (6, "106.000"),
        ]
        self.x_axis_labels = VGroup()
        self.y_axis_labels = VGroup()
        # Place each tick label next to its position on the axis.
        for x_val, x_tex in values_x:
            tex = PangoText(x_tex).scale(0.6)
            tex.next_to(self.coords_to_point(x_val, 0), DOWN)
            self.x_axis_labels.add(tex)
        for y_val, y_tex in values_y:
            tex = PangoText(y_tex).scale(0.6)
            tex.next_to(self.coords_to_point(0, y_val), LEFT)
            self.y_axis_labels.add(tex)
        self.play(
            Write(self.x_axis_labels),
            Write(self.x_axis),
            Write(self.y_axis_labels),
            Write(self.y_axis),
        )
class Hypothese(Scene):
    """Present h_theta(x) = theta_0 + theta_1 x, highlight its parts, then substitute x = 3 and h = 9."""

    def construct(self):
        def formula(*parts):
            # Substring indices: 0 left side, 1 "=", 2 theta_0, 3 "+", 4 theta_1, 5 x
            return MathTex(*parts).scale(2)

        hypothesis = formula("h_\\theta\\left(x\\right)", "=", "\\theta_0", "+", "{\\theta_1}", "x")
        at_three = formula("h_\\theta\\left(3\\right)", "=", "\\theta_0", "+", "{\\theta_1}", "3")
        solved = formula("9", "=", "\\theta_0", "+", "{\\theta_1}", "3")

        self.wait(1)
        self.play(Write(hypothesis))
        self.wait(6)
        # Highlight theta_1, then theta_0, then the left-hand side.
        highlights = [SurroundingRectangle(hypothesis[k], buff=0.1) for k in (4, 2, 0)]
        self.play(ShowCreation(highlights[0]))
        self.wait(2)
        self.play(ReplacementTransform(highlights[0], highlights[1]))
        self.wait(1)
        self.play(ReplacementTransform(highlights[1], highlights[2]))
        self.wait(3)
        self.play(FadeOut(highlights[2]))
        self.wait(1)
        self.play(ReplacementTransform(hypothesis, at_three))
        self.wait(3)
        self.play(ReplacementTransform(at_three, solved))
        self.wait(1)
        self.play(FadeOut(solved))
class CombinacionJmin(Scene):
    """Show min J(theta_0, theta_1) -> 9, then cycle through concrete (theta_0, theta_1) guesses."""

    def construct(self):
        # Arguments of J, in display order: the symbolic pair, then the guesses.
        arguments = [r"\theta_0, \theta_1", "0, 7", "2, 5", "-3, 1", "-2, 2", "-1, 3"]
        # Every \left( must be closed by a matching \right); an unpaired
        # \left( is a LaTeX error ("Extra }, or forgotten \right").
        formulas = [
            Tex(r"$\min{J\left(" + args + r"\right)}\to{9}$").scale(2)
            for args in arguments
        ]
        self.wait(1)
        self.play(Write(formulas[0]))
        self.wait(3)
        for current, following in zip(formulas, formulas[1:]):
            self.play(ReplacementTransform(current, following))
        self.wait(8)
        self.play(FadeOut(formulas[-1]))
class SomDifferencia(Scene):
    """Write the squared-error cost J(theta_0, theta_1) and highlight the sum, h_theta(x^(i)) and y^(i) in turn."""

    def construct(self):
        # Substring indices: 0 J(..), 1 "=", 2 1/2m, 3 sum, 4 "(", 5 h(x^(i)), 6 "-", 7 y^(i), 8 ")^2"
        cost = MathTex(
            "J(\\theta_{0}, \\theta_{1})",
            "=",
            "\\frac{1}{2m}",
            "\\sum\\limits_{i=1}^m",
            "(",
            "h_{\\theta}(x^{(i)})",
            "-",
            "y^{(i)}",
            ")^2",
        ).scale(1)
        sum_box, h_box, y_box = (
            SurroundingRectangle(cost[k], buff=0.1) for k in (3, 5, 7)
        )
        self.wait(1)
        self.play(Write(cost))
        self.wait(5)
        self.play(ShowCreation(sum_box))
        self.wait(1)
        self.play(ReplacementTransform(sum_box, h_box))
        self.play(ReplacementTransform(h_box, y_box))
        self.wait(2)
        self.play(FadeOut(y_box))
        self.wait(9)
        self.play(FadeOut(cost))
class GradientDescentDilanti(MovingCameraScene):
    """Fit the camera to the "Gradient Descent" title, then show min J with a growing parameter list.

    The parameter list grows from theta_0..theta_1 to theta_0..theta_5 and
    then returns to two parameters.
    """

    @staticmethod
    def _jmin(n_params):
        """Return Tex for min J(theta_0, ..., theta_{n_params - 1})."""
        thetas = ", ".join(r"\theta_%d" % i for i in range(n_params))
        # \left( must be paired with \right); an unpaired \left( is a LaTeX error.
        return Tex(r"$\min{J\left(" + thetas + r"\right)}$")

    def construct(self):
        gd = PangoText('Gradient Descent', gradient=(BLUE, GREEN)).scale(2)
        self.wait(3)
        self.play(Write(gd))
        self.wait(1)
        # Set the camera frame width to 1.2x the title width.
        self.play(self.camera_frame.set_width, gd.get_width() * 1.2)
        self.wait(2)
        self.play(FadeOut(gd))
        self.wait(1)

        formulas = [self._jmin(n) for n in (2, 3, 4, 5, 6, 2)]
        self.play(Write(formulas[0]))
        self.wait(1)
        # Grow 2 -> 6 parameters back-to-back ...
        for current, following in zip(formulas[:4], formulas[1:5]):
            self.play(ReplacementTransform(current, following))
        self.wait(1)
        # ... then return to the two-parameter form.
        self.play(ReplacementTransform(formulas[4], formulas[5]))
        self.wait(3)
        self.play(FadeOut(formulas[5]))
class KedaRipitiYGana(MovingCameraScene):
    """Show the gradient-descent update rule with a brace captioned "ripiti te ora e converge".

    The caption is presumably Papiamento for "repeat until it converges".
    Afterwards the camera frame is fitted to the rule and the rule fades out.
    """

    def construct(self):
        rule = Tex(
            r'$\theta_j := $',
            r'$\theta_j - \alpha \frac{\partial}{\partial\theta_j}J\left(\theta_0, \theta_1\right)$',
        ).scale(1)
        note = Tex(r'(update parew pa $j=0$ y $j=1$ !)').scale(0.75).move_to(2 * DOWN)
        self.wait(1)
        self.play(Write(rule), Write(note))
        # Brace over the right-hand side (substring 1) of the rule.
        brace = Brace(rule[1], UP, buff=SMALL_BUFF)
        caption = brace.get_text("ripiti te ora e converge")
        self.play(GrowFromCenter(brace), FadeIn(caption))
        self.wait(9)
        self.play(FadeOut(caption), FadeOut(brace), FadeOut(note))
        # Set the camera frame width to 1.6x the width of the rule.
        self.play(self.camera_frame.set_width, rule.get_width() * 1.6)
        self.play(FadeOut(rule))
class CordaCalculus(GraphScene):
    """Recall dy/dx, then slide a dot and its tangent line along y = (x - 5)**2."""

    CONFIG = {
        "y_axis_label": r"$y$",
        "x_axis_label": r"$x$",
        "y_max": 10,
        "y_min": 0,
        "y_tick_frequency": 1,
        "x_max": 10,
        "x_min": 0,
        "axes_color": BLUE,
    }

    def construct(self):
        self.wait(1)
        deriv = Tex(r'$\frac{dy}{dx}$').scale(3)
        self.play(Write(deriv))
        self.wait(5)
        self.play(FadeOut(deriv))
        self.setup_axes(animate=True)

        def graph_to_be_drawn(x):
            return (x - 5) ** 2

        def dx(x):
            # Analytic derivative of graph_to_be_drawn.
            return 2 * (x - 5)

        parabola = self.get_graph(
            graph_to_be_drawn,
            x_min=2,
            x_max=8,
            color=YELLOW,
            stroke_opacity=0.5,
        )
        # Drives the x position of the dot and the tangent line.
        vt = ValueTracker(0)

        def moving_dot():
            x = vt.get_value()
            return Dot(color=WHITE).move_to(self.coords_to_point(x, graph_to_be_drawn(x)))

        md = always_redraw(moving_dot)

        def get_w_line():
            """Tangent segment of length 2 centred on the current point of the parabola."""
            x = vt.get_value()
            y = graph_to_be_drawn(x)
            centre = self.coords_to_point(x, y)
            # Measure the tangent direction on screen, not in graph units: the
            # x and y axes generally have different screen lengths per unit, so
            # an angle computed from the raw slope would be tilted wrongly.
            direction = self.coords_to_point(x + 1, y + dx(x)) - centre
            direction = direction / np.linalg.norm(direction)
            return Line(centre - direction, centre + direction, stroke_opacity=1, color=RED)

        vt.set_value(3)
        line = always_redraw(get_w_line)
        self.play(ShowCreation(parabola), FadeIn(md), FadeIn(line))
        self.wait(1)
        self.play(vt.set_value, 7, rate_func=there_and_back, run_time=4)
        self.wait(1)
        self.play(vt.set_value, 5, rate_func=slow_into, run_time=4)
        self.wait(6)
        self.play(vt.set_value, 3, rate_func=slow_into, run_time=1)
        self.play(vt.set_value, 5, rate_func=slow_into, run_time=6)
        self.wait(1)
        self.play(FadeOut(parabola), FadeOut(md), FadeOut(line), FadeOut(self.axes))
class MinTekenMeiMei(MovingCameraScene):
    """Explain the minus sign in the update rule.

    After showing the rule with its "repeat until convergence" brace, the
    scene contrasts dJ/dtheta_j > 0 ("descent") with dJ/dtheta_j < 0
    ("ascent"), then returns to the rule.
    """

    def construct(self):
        rule_parts = (
            r'$\theta_j := $',
            r'$\theta_j - \alpha \frac{\partial}{\partial\theta_j}J\left(\theta_0, \theta_1\right)$',
        )
        rule = Tex(*rule_parts).scale(1)
        note = Tex(r'(update parew pa $j=0$ y $j=1$ !)').scale(0.75).move_to(2 * DOWN)
        brace = Brace(rule[1], UP, buff=SMALL_BUFF)
        rhs_box = SurroundingRectangle(rule[1], buff=0.1)  # whole right-hand side
        caption = brace.get_text("ripiti te ora e converge")

        self.wait(1)
        self.play(Write(rule), Write(note), GrowFromCenter(brace), FadeIn(caption))
        self.play(FadeIn(rhs_box))
        self.wait(3)

        positive = Tex(r'$\frac{\partial}{\partial\theta_j}J\left(\theta_0, \theta_1\right) > 0 \to$ descent').scale(1)
        self.play(
            FadeOut(brace),
            FadeOut(caption),
            FadeOut(note),
            FadeOut(rhs_box),
            ReplacementTransform(rule, positive),
        )
        self.wait(4)

        negative = Tex(r'$\frac{\partial}{\partial\theta_j}J\left(\theta_0, \theta_1\right) < 0 \to$ ascent').scale(1)
        self.play(ReplacementTransform(positive, negative))
        self.wait(4)

        rule_again = Tex(*rule_parts).scale(1)
        self.play(ReplacementTransform(negative, rule_again))
        self.wait(4)
        self.play(FadeOut(rule_again))
class Alpha(GraphScene):
    """Illustrate the learning rate alpha with a dot and tangent moving on y = (x - 5)**2.

    The alpha label is relabelled 2.0, 1.1, 0.05 and 2.2. After each relabel
    the dot is moved with hand-picked targets, easing and durations; the motion
    is scripted, not computed from alpha.
    """

    CONFIG = {
        "y_axis_label": r"$y$",
        "x_axis_label": r"$x$",
        "y_max": 10,
        "y_min": 0,
        "y_tick_frequency": 1,
        "x_max": 10,
        "x_min": 0,
        "axes_color": BLUE,
    }

    def construct(self):
        self.wait(1)
        alpha = Tex(r'$\alpha$').scale(3)
        self.play(Write(alpha))
        self.play(ApplyMethod(alpha.shift, (UP + RIGHT) * PI))
        self.setup_axes(animate=True)

        def graph_to_be_drawn(x):
            return (x - 5) ** 2

        def dx(x):
            # Analytic derivative of graph_to_be_drawn.
            return 2 * (x - 5)

        parabola = self.get_graph(
            graph_to_be_drawn,
            x_min=2,
            x_max=8,
            color=YELLOW,
            stroke_opacity=0.5,
        )
        # Drives the x position of the dot and the tangent line.
        vt = ValueTracker(0)

        def moving_dot():
            x = vt.get_value()
            return Dot(color=WHITE).move_to(self.coords_to_point(x, graph_to_be_drawn(x)))

        md = always_redraw(moving_dot)

        def get_w_line():
            """Tangent segment of length 2 centred on the current point of the parabola."""
            x = vt.get_value()
            y = graph_to_be_drawn(x)
            centre = self.coords_to_point(x, y)
            # Measure the tangent direction on screen, not in graph units: the
            # x and y axes generally have different screen lengths per unit, so
            # an angle computed from the raw slope would be tilted wrongly.
            direction = self.coords_to_point(x + 1, y + dx(x)) - centre
            direction = direction / np.linalg.norm(direction)
            return Line(centre - direction, centre + direction, stroke_opacity=1, color=RED)

        vt.set_value(3)
        line = always_redraw(get_w_line)
        self.play(
            ShowCreation(parabola),
            FadeIn(md),
            FadeIn(line),
            ApplyMethod(alpha.scale, 1 / 2),
        )

        # (alpha label, ((target x, rate_func, run_time), ...)) in playback order.
        script = (
            (r'$\alpha = 2.0$', ((6, there_and_back, 2),)),
            (r'$\alpha = 1.1$', ((4, there_and_back, 2), (5, slow_into, 12))),
            (r'$\alpha = 0.05$', ((4, there_and_back, 6),)),
            (r'$\alpha = 2.2$', ((3, rush_into, 2), (7, rush_into, 3))),
        )
        label = alpha
        for tex, moves in script:
            new_label = Tex(tex)
            self.play(ReplacementTransform(label, new_label), run_time=0.5)
            label = new_label
            for target, rate, seconds in moves:
                self.play(vt.set_value, target, rate_func=rate, run_time=seconds)
        self.wait(3)
        self.play(
            FadeOut(parabola),
            FadeOut(md),
            FadeOut(line),
            FadeOut(self.axes),
            FadeOut(label),
        )
class SaddlePoint(ThreeDScene):
    """Draw 3D axes and the saddle surface, then let the ambient camera rotation run for 50 seconds.

    The ambient rotation is never stopped explicitly in this scene.
    """

    def construct(self):
        axes_3d = ThreeDAxes(animate=True)
        saddle = ThreeDSurface()
        self.set_camera_orientation(phi=75 * DEGREES, theta=30 * DEGREES, distance=30)
        self.begin_ambient_camera_rotation(rate=0.1)
        self.wait(1)
        for mob in (axes_3d, saddle):
            self.play(ShowCreation(mob))
        self.wait(50)
        for mob in (saddle, axes_3d):
            self.play(FadeOut(mob))
class TipoDiGradientDescent(Scene):
    """Arrange nine gradient-descent variant names around "SGD", write them one by one, then fade them out."""

    def construct(self):
        names = ('SGD', 'RMSprop', 'Adam', 'Adadelta', 'Adagrad',
                 'Adamax', 'Nadam', 'Ftrl', 'BGD')
        labels = [Tex(name).scale(2) for name in names]
        sgd, rmsprop, adam, adadelta, adagrad, adamax, nadam, ftrl, bgd = labels
        # (label, anchor, direction). Order matters: Adadelta and Adagrad must
        # be placed around SGD before they serve as anchors themselves.
        layout = (
            (rmsprop, sgd, DOWN * 1.5),
            (adam, sgd, UP * 1.5),
            (adadelta, sgd, LEFT * 1.5),
            (adagrad, sgd, RIGHT * 1.5),
            (adamax, adadelta, UP * 2),
            (nadam, adagrad, UP * 2),
            (ftrl, adagrad, DOWN * 2),
            (bgd, adadelta, DOWN * 2),
        )
        for label, anchor, direction in layout:
            label.next_to(anchor, direction)

        self.wait(1)
        for label in labels:
            self.play(Write(label))
        self.wait(12)
        self.play(*(FadeOut(label) for label in reversed(labels)))
class Outro(Scene):
    """Closing card.

    Writes "Masha Danki!" (presumably Papiamento for "thank you very much")
    with a blue-to-green gradient, holds it for 3 seconds, then fades it out.
    """

    def construct(self):
        message = PangoText('Masha Danki!', gradient=(BLUE, GREEN))
        message.scale(2)
        self.wait(1)
        # The message is added to the scene first and then animated with Write.
        self.add(message)
        self.play(Write(message))
        self.wait(3)
        self.play(FadeOut(message))
        self.wait(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment