Skip to content

Instantly share code, notes, and snippets.

@EteimZ
Created July 27, 2024 17:30
Show Gist options
  • Save EteimZ/25312724291b97fe6fde3c6e8806a9a2 to your computer and use it in GitHub Desktop.
This is a manim visualization for LoRA inference.
class Inference(Scene):
    """Manim scene illustrating LoRA inference.

    The scene draws the frozen weight matrix next to the two low-rank
    factors ``A`` (labelled m x r) and ``B`` (labelled r x n), collapses
    ``A`` and ``B`` into their product (the adapter weight), and finally
    merges the frozen and adapter weights into a single matrix.

    NOTE(review): relies on ``from manim import *`` and ``numpy as np``
    being imported at module level; those lines are not visible here.
    """

    def construct(self):
        """Build and play the LoRA inference animation."""
        title = Title("LoRA Inference", include_underline=False).scale(0.75)
        self.add(title)

        # Populates self.frozen_weight, self.A_values, self.B_values,
        # self.composed_weight and self.frozen_composed_weight.
        self.computation()

        weight_matrix = self.create_weight_matrix("Frozen Weight", BLUE, self.frozen_weight)
        composed_matrix = self.create_weight_matrix(
            "Adapter Weight", RED, self.composed_weight
        ).move_to(2 * RIGHT)
        A = self.create_A_matrix(self.A_values)
        B = self.create_B_matrix(self.B_values)
        # Multiplication dot, positioned by hand between A and B.
        dot = Dot().move_to(np.array([2.2, 0.2, 0.0]))
        matrices = VGroup(A, B)

        B.move_to(4 * RIGHT)
        A.move_to(RIGHT)
        weight_matrix.move_to(2 * LEFT)

        self.play(Create(weight_matrix))
        self.play(Create(A))
        self.play(Create(dot))
        self.play(Create(B))
        self.wait(1.5)

        # Collapse the two factors into their product (the adapter weight).
        self.play(ReplacementTransform(matrices, composed_matrix), Uncreate(dot))

        add = Text("+").move_to(np.array([0.0, 0.2, 0.0]))
        self.play(Write(add))
        weight_matrices = VGroup(weight_matrix, composed_matrix)
        weight_adapter_matrix = self.create_weight_matrix(
            "Frozen + Adapter Weights", YELLOW, self.frozen_composed_weight
        )
        self.wait(2)
        self.play(Unwrite(add))
        # Merge frozen and adapter weights into the single inference matrix.
        self.play(ReplacementTransform(weight_matrices, weight_adapter_matrix))
        self.wait(5)

    def _build_matrix(self, values, color, label_text, col_text, row_text, side):
        """Return a VGroup that draws ``values`` as a grid of numbered cells.

        Args:
            values: 2-D sequence of numbers; its shape sets the grid size.
                Row ``i`` is drawn at y = 0.5 * i, so ``values[0]`` ends up
                at the bottom (callers pass row-reversed data, see
                ``computation``).
            color: stroke colour of the cell squares.
            label_text: title drawn above the grid.
            col_text: dimension label placed under the bottom brace.
            row_text: dimension label placed beside the side brace.
            side: ``LEFT`` or ``RIGHT``; which side gets the row brace.

        Returns:
            A VGroup centred (grid-wise) at ORIGIN containing, in order, the
            (number, square) pair of every cell, then the column label, row
            label, bottom brace, side brace and title.
        """
        group = VGroup()
        for i, row in enumerate(values):
            for j, value in enumerate(row):
                cell = Square(color=color).scale(0.2).set_x(0.5 * j).set_y(0.5 * i)
                number = DecimalNumber(value).scale(0.4).move_to(cell.get_center())
                group.add(number, cell)
        group.move_to(ORIGIN)

        # Braces and title are positioned against the bare grid before they
        # are added, so they do not affect each other's placement.
        label = Text(label_text).scale(0.4).next_to(group, UP)
        bottom_brace = Brace(group, direction=DOWN)
        side_brace = Brace(group, direction=side)
        col_label = Text(col_text).scale(0.4).next_to(bottom_brace, DOWN)
        row_label = Text(row_text).scale(0.4).next_to(side_brace, side)

        group.add(col_label, row_label, bottom_brace, side_brace, label)
        return group

    def create_A_matrix(self, values):
        """Draw the low-rank factor A (rows labelled "m", columns "r")."""
        return self._build_matrix(values, RED, "A", "r", "m", LEFT)

    def create_B_matrix(self, values):
        """Draw the low-rank factor B (rows labelled "r", columns "n").

        The row brace sits on the right so it does not collide with A.
        """
        return self._build_matrix(values, RED, "B", "n", "r", RIGHT)

    def create_weight_matrix(self, matrix_label, color, values):
        """Draw a full weight matrix (rows labelled "m", columns "n")."""
        return self._build_matrix(values, color, matrix_label, "n", "m", LEFT)

    def computation(self):
        """Compute the matrices shown in the scene and store them on ``self``.

        Sets ``frozen_weight`` (5x4), ``A_values`` (5x2), ``B_values`` (2x4),
        ``composed_weight = A_values @ B_values`` (5x4) and
        ``frozen_composed_weight = frozen_weight + composed_weight``.

        All arithmetic is done on the un-reversed matrices; each stored array
        is then row-reversed because the grid builder draws row 0 at the
        bottom, and reversing puts mathematical row 0 at the top on screen.
        """
        # The frozen (pre-trained) weight.
        frozen_weight = np.array([
            [0.22, 0.59, 0.81, 0.01],
            [0.03, 0.20, 0.65, 0.54],
            [0.42, 0.03, 0.22, 0.51],
            [0.74, 0.68, 0.89, 0.09],
            [0.64, 0.03, 0.28, 0.22],
        ])
        # The two low-rank factors.
        A_values = np.array([[0.64, 0.36], [0.10, 0.28], [0.08, 0.23], [0.23, 0.29], [0.70, 0.05]])
        B_values = np.array([[0.65, 0.61, 0.17, 0.73], [0.37, 0.21, 0.27, 0.94]])
        # Adapter weight reconstructed from the factors, then merged in.
        composed_weight = A_values @ B_values
        frozen_composed_weight = frozen_weight + composed_weight

        self.frozen_weight = frozen_weight[::-1]
        self.A_values = A_values[::-1]
        self.B_values = B_values[::-1]
        self.composed_weight = composed_weight[::-1]
        self.frozen_composed_weight = frozen_composed_weight[::-1]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment