Created
July 27, 2024 17:30
-
-
Save EteimZ/25312724291b97fe6fde3c6e8806a9a2 to your computer and use it in GitHub Desktop.
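This is a Manim visualization of LoRA inference: the adapter matrices A and B are multiplied together and the product is added to the frozen weight matrix.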
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Inference(Scene):
    """Manim scene that animates LoRA inference.

    The frozen weight is shown next to the adapter factors ``A`` and ``B``;
    ``A`` and ``B`` are then replaced by their product, and finally the frozen
    weight and the product are replaced by their sum.
    """

    def construct(self):
        """Build the mobjects and play the animation."""
        title = Title("LoRA Inference", include_underline=False).scale(0.75)
        self.add(title)

        # Compute the values of the matrices used in the animation.
        self.computation()

        weight_matrix = self.create_weight_matrix("Frozen Weight", BLUE, self.frozen_weight)
        composed_matrix = self.create_weight_matrix(
            "Adapter Weight", RED, self.composed_weight
        ).move_to(2 * RIGHT)
        A = self.create_A_matrix(self.A_values)
        B = self.create_B_matrix(self.B_values)
        # The dot between A and B stands for the matrix product.
        dot = Dot().move_to(np.array([2.2, 0.2, 0.]))
        matrices = VGroup(A, B)

        B.move_to(4 * RIGHT)
        A.move_to(RIGHT)
        weight_matrix.move_to(2 * LEFT)

        self.play(Create(weight_matrix))
        self.play(Create(A))
        self.play(Create(dot))
        self.play(Create(B))
        self.wait(1.5)

        # Collapse A . B into the single composed (adapter) matrix.
        self.play(ReplacementTransform(matrices, composed_matrix), Uncreate(dot))

        add = Text("+").move_to(np.array([0.0, 0.2, 0.]))
        self.play(Write(add))

        weight_matrices = VGroup(weight_matrix, composed_matrix)
        weight_adapter_matrix = self.create_weight_matrix(
            "Frozen + Adapter Weights", YELLOW, self.frozen_composed_weight
        )
        self.wait(2)
        self.play(Unwrite(add))

        # Merge the frozen weight and the adapter weight into their sum.
        self.play(ReplacementTransform(weight_matrices, weight_adapter_matrix))
        self.wait(5)

    def _build_matrix(self, label_text, color, values, row_name, col_name,
                      side, *, square_first=False):
        """Return a VGroup drawing ``values`` as a labelled grid of cells.

        Args:
            label_text: Caption placed above the grid.
            color: Stroke colour of the cell squares.
            values: 2-D array-like; one cell is drawn per entry.
            row_name: Text shown next to the side brace (row dimension).
            col_name: Text shown below the bottom brace (column dimension).
            side: ``LEFT`` or ``RIGHT``; which side gets the row brace.
            square_first: If true, each cell's square is added to the group
                before its number; otherwise the number is added first.

        Raises:
            ValueError: If ``values`` is not two-dimensional.
        """
        shape = np.shape(values)
        if len(shape) != 2:
            raise ValueError("values must be a 2-D array")
        n_rows, n_cols = shape

        grid = VGroup()
        for i in range(n_rows):
            for j in range(n_cols):
                cell = Square(color=color).scale(0.2).set_x(0.5 * j).set_y(0.5 * i)
                number = DecimalNumber(values[i][j]).scale(0.4).move_to(cell.get_center())
                if square_first:
                    grid.add(cell, number)
                else:
                    grid.add(number, cell)
        grid.move_to(ORIGIN)

        # The caption and braces are positioned against the bare grid, before
        # any of them is added to the group.
        label = Text(label_text).scale(0.4)
        label.next_to(grid, UP)
        bottom_brace = Brace(grid, direction=DOWN)
        side_brace = Brace(grid, direction=side)

        col_d = Text(col_name).scale(0.4)
        row_d = Text(row_name).scale(0.4)
        col_d.next_to(bottom_brace, DOWN)
        row_d.next_to(side_brace, side)

        grid.add(col_d, row_d, bottom_brace, side_brace, label)
        return grid

    def create_A_matrix(self, values):
        """Return the red grid for adapter factor ``A`` (rows ``m``, columns ``r``)."""
        return self._build_matrix("A", RED, values, "m", "r", LEFT)

    def create_B_matrix(self, values):
        """Return the red grid for adapter factor ``B`` (rows ``r``, columns ``n``).

        The row brace sits on the right, and each square is added before its
        number, matching the submobject order this matrix has always had.
        """
        return self._build_matrix("B", RED, values, "r", "n", RIGHT, square_first=True)

    def create_weight_matrix(self, matrix_label, color, values):
        """Return a grid captioned ``matrix_label`` (rows ``m``, columns ``n``)."""
        return self._build_matrix(matrix_label, color, values, "m", "n", LEFT)

    def computation(self):
        """Compute the matrices shown in the animation and store them on ``self``.

        Sets ``frozen_weight``, ``A_values``, ``B_values``, ``composed_weight``
        (``A . B``) and ``frozen_composed_weight`` (frozen + composed), each
        with its rows reversed for display.
        """
        # This array represents the frozen weight.
        frozen_weight = np.array([
            [0.22, 0.59, 0.81, 0.01],
            [0.03, 0.20, 0.65, 0.54],
            [0.42, 0.03, 0.22, 0.51],
            [0.74, 0.68, 0.89, 0.09],
            [0.64, 0.03, 0.28, 0.22],
        ])

        # The arrays below represent the decomposed (low-rank) matrices.
        A_values = np.array([
            [0.64, 0.36],
            [0.10, 0.28],
            [0.08, 0.23],
            [0.23, 0.29],
            [0.70, 0.05],
        ])
        B_values = np.array([
            [0.65, 0.61, 0.17, 0.73],
            [0.37, 0.21, 0.27, 0.94],
        ])

        # This array represents the composed matrix.
        composed_weight = A_values.dot(B_values)

        # The frozen matrix is added to the composed matrix.
        frozen_composed_weight = frozen_weight + composed_weight

        # The grid builders place row i at y = 0.5 * i, so higher row indices
        # are drawn higher up; reverse the rows so that the first row of each
        # matrix ends up at the top of its grid.
        self.frozen_weight = frozen_weight[::-1]
        self.A_values = A_values[::-1]
        self.B_values = B_values[::-1]
        self.composed_weight = composed_weight[::-1]
        self.frozen_composed_weight = frozen_composed_weight[::-1]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment