This explains a trick that's relevant to AOC 2021 day 6.
First, let's do Fibonacci numbers, because it's smaller (a 2x2 matrix instead of 9x9) and it's familiar ground.
You can implement fib like this:
```python
def fib(n):
    x = 0
    y = 1
    for _ in range(0, n):
        x,y = y,x+y  # slide the pair (x, y) one step along the sequence
    return y

print(fib(1))
print(fib(2))
print(fib(3))
print(fib(4))
print(fib(5))
```
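A matrix is a grid of numbers. I forget this every time, but sizes are given as rows, then columns. So

$$\begin{pmatrix} 3 & 4 & 5 \\ 7 & 2 & 3 \end{pmatrix}$$

is a 2x3 matrix (2 rows, 3 columns).
A vector is just a matrix that's n x 1: a single column of n numbers.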
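To make "rows, then columns" concrete, here is one way to hold these in Python. This is an assumption for illustration: a list of rows, whereas the code later in this post stores the 2x2 matrix as a flat list instead.

```python
# The 2x3 matrix above, stored as a list of rows: 2 rows, 3 columns.
m = [
    [3, 4, 5],
    [7, 2, 3],
]
rows, cols = len(m), len(m[0])
print(rows, cols)  # 2 3
print(m[1][0])     # row 1, column 0 -> 7

# A vector is an n x 1 matrix: n rows, each holding one entry.
v = [[0], [1]]
print(len(v), len(v[0]))  # 2 1
```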
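I'll show the formula for multiplying a 2x2 matrix by a 2x1 vector (the result is another 2x1 vector):

$$\begin{pmatrix} a & b \\ c & d \end{pmatrix} \begin{pmatrix} x \\ y \end{pmatrix} = \begin{pmatrix} ax + by \\ cx + dy \end{pmatrix}$$

and the implementation in Python, with the matrix stored as a flat list [a, b, c, d], row by row: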
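```python
def m2x2_times_v2x1(m, v):
    # m is a 2x2 matrix as a flat row-major list [a, b, c, d]; v is [x, y]
    a, b, c, d = m[0], m[1], m[2], m[3]
    x, y = v[0], v[1]
    return [a*x + b*y, c*x + d*y]
```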
Now, as for why it's ax+by and not ax+bx or anything else, there's a whole bunch of math behind that which people spend years learning. But really all that matters here is that we can use a 2x2 matrix to take our values x, y and compute two new values: each new value is one row of the matrix multiplied entry by entry with (x, y) and added up.
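That "row times vector" rule works for any size, which matters because day 6 needs a 9x9 matrix. Here is a minimal sketch for a matrix stored as a list of rows and a vector stored as a flat list (the name `mat_vec` is hypothetical, not from the original code):

```python
def mat_vec(m, v):
    """Multiply a matrix (list of rows) by a vector (flat list)."""
    # Entry i of the result is row i of m multiplied entry by entry with v, then summed.
    return [sum(a * x for a, x in zip(row, v)) for row in m]

# Same answer as the 2x2 formula [a*x + b*y, c*x + d*y]:
print(mat_vec([[1, 2], [3, 4]], [5, 6]))  # [17, 39]
```

For a 2x2 matrix this does exactly what `m2x2_times_v2x1` does; the only difference is that it loops over rows instead of spelling out each term.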
Look at what this specific matrix does:
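$$\begin{pmatrix} 0 & 1 \\ 1 & 1 \end{pmatrix} \begin{pmatrix} x \\ y \end{pmatrix} = \begin{pmatrix} 0 \cdot x + 1 \cdot y \\ 1 \cdot x + 1 \cdot y \end{pmatrix} = \begin{pmatrix} y \\ x + y \end{pmatrix}$$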
This is exactly the loop body of the fib program (x,y = y,x+y), so I can rewrite fib to use a matrix:
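```python
def fib_m(n):
    # the Fibonacci step matrix, flat row-major
    m = [0, 1,
         1, 1]
    v = [0, 1]  # starting (x, y), same as in fib
    for _ in range(0, n):
        v = m2x2_times_v2x1(m, v)  # (x, y) -> (y, x + y)
    return v[1]
```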
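Mathematically, what we are doing to calculate fib_m(5) (for example) is taking the matrix and vector

$$m = \begin{pmatrix} 0 & 1 \\ 1 & 1 \end{pmatrix}, \quad v = \begin{pmatrix} 0 \\ 1 \end{pmatrix}$$

and then calculating

$$m \cdot m \cdot m \cdot m \cdot m \cdot v$$

and taking the second coordinate of the vector we get.
Because matrix multiplication is associative, it doesn't matter whether we apply m to v five times or multiply the five m's together first, so this can be written as $m^5 v$.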
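To compute m^5 on its own, before touching v, we need matrix-times-matrix multiplication: entry (i, j) of the product is row i of the left matrix times column j of the right one. A sketch using flat row-major lists like `m2x2_times_v2x1` (the name `m2x2_times_m2x2` is hypothetical):

```python
def m2x2_times_m2x2(p, q):
    """Multiply two 2x2 matrices stored flat in row-major order [a, b, c, d]."""
    a, b, c, d = p
    e, f, g, h = q
    return [a*e + b*g, a*f + b*h,
            c*e + d*g, c*f + d*h]

m = [0, 1,
     1, 1]
m5 = m
for _ in range(4):  # m5 starts as m^1; four more multiplications give m^5
    m5 = m2x2_times_m2x2(m5, m)
print(m5)  # [3, 5, 5, 8]

# Now apply m^5 to v = [0, 1]; the second coordinate matches fib_m(5).
v = [0, 1]
mv = [m5[0]*v[0] + m5[1]*v[1], m5[2]*v[0] + m5[3]*v[1]]
print(mv)  # [5, 8]
```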
And there is a really clever technique to efficiently calculate m^n called binary exponentiation.
https://en.wikipedia.org/wiki/Exponentiation_by_squaring
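Using binary exponentiation cuts the work from O(n) matrix multiplications to O(log n): it only needs the repeated squares m, m^2, m^4, m^8, ... and multiplies together the ones that match the 1-bits in the binary representation of n.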
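A minimal sketch of exponentiation by squaring applied to the Fibonacci matrix, assuming flat row-major 2x2 lists as above. The helper names `m2x2_times_m2x2`, `m2x2_pow` and `fib_fast` are made up for this example.

```python
def m2x2_times_m2x2(p, q):
    """Multiply two 2x2 matrices stored flat in row-major order [a, b, c, d]."""
    a, b, c, d = p
    e, f, g, h = q
    return [a*e + b*g, a*f + b*h,
            c*e + d*g, c*f + d*h]


def m2x2_pow(m, n):
    """Compute m^n for a flat 2x2 matrix using exponentiation by squaring."""
    result = [1, 0,
              0, 1]  # identity matrix, i.e. m^0
    while n > 0:
        if n & 1:  # lowest bit of n is set: multiply the current square into result
            result = m2x2_times_m2x2(result, m)
        m = m2x2_times_m2x2(m, m)  # m, m^2, m^4, m^8, ...
        n >>= 1
    return result


def fib_fast(n):
    """Same output as fib(n) and fib_m(n), using O(log n) matrix multiplications."""
    p = m2x2_pow([0, 1, 1, 1], n)
    # m^n times [0, 1] is the second column of m^n, i.e. [p[1], p[3]];
    # fib_m returns the second coordinate, so return p[3].
    return p[3]


print([fib_fast(n) for n in range(1, 6)])  # [1, 2, 3, 5, 8]
```

The count of multiplications is O(log n), but Python integers grow with n, so for very large n each multiplication also gets more expensive.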
Here's the full script with fib and fib_m (both still use the plain O(n) loop, and both print 1, 2, 3, 5, 8):

```python
def fib(n):
    a = 0
    b = 1
    for _ in range(0, n):
        a,b = b,a+b
    return b

print(fib(1))
print(fib(2))
print(fib(3))
print(fib(4))
print(fib(5))
print("")

def m2x2_times_v2x1(m,v):
    a,b,c,d=m[0],m[1],m[2],m[3]
    x,y=v[0],v[1]
    return [a*x+b*y, c*x+d*y]

def fib_m(n):
    m=[0,1,
       1,1]
    v=[0,1]
    for _ in range(0, n):
        v = m2x2_times_v2x1(m, v)
    return v[1]

print("")
print(fib_m(1))
print(fib_m(2))
print(fib_m(3))
print(fib_m(4))
print(fib_m(5))
```
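Day 6 works the same way, just with a 9x9 matrix. A sketch, assuming the puzzle rules as I understand them: each fish has a timer from 0 to 8, every day each timer counts down by one, and a fish at 0 resets to 6 and spawns a new fish at 8. The state vector holds how many fish have each timer value. All function names here (`mat_mul`, `mat_pow`, `step_matrix`, `count_fish`) are hypothetical.

```python
def mat_mul(p, q):
    """Multiply two n x n matrices stored as lists of rows."""
    n = len(p)
    return [[sum(p[i][k] * q[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def mat_pow(m, e):
    """Compute m^e by exponentiation by squaring."""
    n = len(m)
    result = [[int(i == j) for j in range(n)] for i in range(n)]  # identity = m^0
    while e > 0:
        if e & 1:
            result = mat_mul(result, m)
        m = mat_mul(m, m)
        e >>= 1
    return result


def step_matrix():
    """9x9 matrix mapping today's fish count per timer value (0..8) to tomorrow's."""
    m = [[0] * 9 for _ in range(9)]
    for t in range(8):
        m[t][t + 1] = 1  # a fish with timer t+1 has timer t tomorrow
    m[6][0] = 1  # a fish at timer 0 resets to 6...
    m[8][0] = 1  # ...and spawns a new fish at timer 8
    return m


def count_fish(timers, days):
    counts = [0] * 9
    for t in timers:
        counts[t] += 1
    p = mat_pow(step_matrix(), days)
    # apply m^days to the count vector, then add up all the fish
    return sum(sum(p[i][j] * counts[j] for j in range(9)) for i in range(9))


example = [3, 4, 3, 1, 2]  # the example input from the puzzle description
print(count_fish(example, 18))   # 26
print(count_fish(example, 80))   # 5934
print(count_fish(example, 256))  # 26984457539
```

Each call does about log2(days) 9x9 multiplications instead of one step per day.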
Interesting read, thanks for sharing.