Skip to content

Instantly share code, notes, and snippets.

@chermehdi
Created April 26, 2019 17:54
Show Gist options
  • Save chermehdi/19f90e0bb3dc3816c7e5bbf2b595bcdb to your computer and use it in GitHub Desktop.
Save chermehdi/19f90e0bb3dc3816c7e5bbf2b595bcdb to your computer and use it in GitHub Desktop.
package io.github.chermehdi.lib.math;
import java.util.Arrays;
/**
* @author Egor Kulikov ([email protected])
*/
public class Matrix {
public static long mod = Long.MAX_VALUE;
public final long[][] data;
public final int rowCount;
public final int columnCount;
public Matrix(int rowCount, int columnCount) {
this.rowCount = rowCount;
this.columnCount = columnCount;
this.data = new long[rowCount][columnCount];
}
public Matrix(long[][] data) {
this.rowCount = data.length;
this.columnCount = data[0].length;
this.data = data;
}
public static Matrix add(Matrix first, Matrix second) {
Matrix result = new Matrix(first.rowCount, first.columnCount);
for (int i = 0; i < result.rowCount; i++) {
for (int j = 0; j < result.columnCount; j++) {
result.data[i][j] = first.data[i][j] + second.data[i][j];
if (result.data[i][j] >= mod) {
result.data[i][j] -= mod;
}
}
}
return result;
}
public static Matrix multiply(Matrix first, Matrix second) {
Matrix result = new Matrix(first.rowCount, second.columnCount);
for (int i = 0; i < first.rowCount; i++) {
for (int j = 0; j < second.rowCount; j++) {
for (int k = 0; k < second.columnCount; k++) {
result.data[i][k] = (result.data[i][k] + first.data[i][j] * second.data[j][k]) % mod;
}
}
}
return result;
}
public static Matrix fastMultiply(Matrix first, Matrix second) {
Matrix result = new Matrix(first.rowCount, second.columnCount);
for (int i = 0; i < first.rowCount; i++) {
for (int j = 0; j < second.rowCount; j++) {
for (int k = 0; k < second.columnCount; k++) {
result.data[i][k] += first.data[i][j] * second.data[j][k];
}
}
}
for (int i = 0; i < first.rowCount; i++) {
for (int j = 0; j < second.columnCount; j++) {
result.data[i][j] %= mod;
}
}
return result;
}
public static Matrix identityMatrix(int size) {
Matrix result = new Matrix(size, size);
for (int i = 0; i < size; i++) {
result.data[i][i] = 1;
}
return result;
}
public static long[] convert(long[][] matrix) {
long[] result = new long[matrix.length * matrix.length];
for (int i = 0; i < matrix.length; i++) {
for (int j = 0; j < matrix.length; j++) {
result[i * matrix.length + j] = matrix[i][j];
}
}
return result;
}
public static long[] sumPowers(long[] matrix, long exponent, long mod, int side) {
long[] result = new long[matrix.length];
long[] power = new long[matrix.length];
long[] temp = new long[matrix.length];
long[] temp2 = new long[matrix.length];
sumPowers(matrix, result, power, temp, temp2, exponent + 1, mod, side);
return result;
}
private static void sumPowers(long[] matrix, long[] result, long[] power, long[] temp,
long[] temp2, long exponent,
long mod, int side) {
if (exponent == 0) {
for (int i = 0; i < matrix.length; i += side + 1) {
power[i] = 1 % mod;
}
return;
}
if ((exponent & 1) == 0) {
sumPowers(matrix, result, temp, power, temp2, exponent >> 1, mod, side);
multiply(temp2, result, temp, mod, side);
add(result, temp2, mod, side);
multiply(power, temp, temp, mod, side);
} else {
sumPowers(matrix, result, temp, power, temp2, exponent - 1, mod, side);
add(result, temp, mod, side);
multiply(power, temp, matrix, mod, side);
}
}
public static long[][] convert(long[] matrix, int side) {
long[][] result = new long[side][side];
for (int i = 0; i < side; i++) {
for (int j = 0; j < side; j++) {
result[i][j] = matrix[i * side + j];
}
}
return result;
}
public static long[] power(long[] matrix, long exponent, long mod, int side) {
long[] result = new long[matrix.length];
long[] temp = new long[result.length];
power(matrix, result, temp, exponent, mod, side);
return result;
}
private static void power(long[] matrix, long[] result, long[] temp, long exponent, long mod,
int side) {
if (exponent == 0) {
for (int i = 0; i < matrix.length; i += side + 1) {
result[i] = 1 % mod;
}
return;
}
if ((exponent & 1) == 0) {
power(matrix, temp, result, exponent >> 1, mod, side);
multiply(result, temp, temp, mod, side);
} else {
power(matrix, temp, result, exponent - 1, mod, side);
multiply(result, temp, matrix, mod, side);
}
}
public static void multiply(long[] c, long[] a, long[] b, long mod, int side) {
Arrays.fill(c, 0);
for (int i = 0; i < side; i++) {
for (int j = 0; j < side; j++) {
for (int k = 0; k < side; k++) {
c[i * side + k] += a[i * side + j] * b[j * side + k];
if ((j & 3) == 3) {
c[i * side + k] %= mod;
}
}
}
}
for (int i = 0; i < c.length; i++) {
c[i] %= mod;
}
}
public static void add(long[] c, long[] a, long mod, int side) {
for (int i = 0; i < side; i++) {
for (int j = 0; j < side; j++) {
c[i * side + j] += a[i * side + j];
if (c[i * side + j] >= mod) {
c[i * side + j] -= mod;
}
}
}
}
public static long[] fastPower(long[] matrix, long exponent, long mod, int side) {
long[] result = new long[matrix.length];
long[] temp = new long[result.length];
fastPower(matrix, result, temp, exponent, mod, side);
return result;
}
private static void fastPower(long[] matrix, long[] result, long[] temp, long exponent, long mod,
int side) {
if (exponent == 0) {
for (int i = 0; i < matrix.length; i += side + 1) {
result[i] = 1;
}
return;
}
if ((exponent & 1) == 0) {
fastPower(matrix, temp, result, exponent >> 1, mod, side);
fastMultiply(result, temp, temp, mod, side);
} else {
power(matrix, temp, result, exponent - 1, mod, side);
fastMultiply(result, temp, matrix, mod, side);
}
}
public static void fastMultiply(long[] c, long[] a, long[] b, long mod, int side) {
Arrays.fill(c, 0);
for (int i = 0; i < side; i++) {
for (int j = 0; j < side; j++) {
for (int k = 0; k < side; k++) {
c[i * side + k] += a[i * side + j] * b[j * side + k];
}
}
}
for (int i = 0; i < c.length; i++) {
c[i] %= mod;
}
}
public Matrix power(long exponent) {
if (exponent == 0) {
return identityMatrix(rowCount);
}
if (exponent == 1) {
return this;
}
Matrix result = power(exponent >> 1);
result = multiply(result, result);
if ((exponent & 1) == 1) {
result = multiply(result, this);
}
return result;
}
public Matrix fastPower(long exponent) {
if (exponent == 0) {
return identityMatrix(rowCount);
}
if (exponent == 1) {
return this;
}
Matrix result = power(exponent >> 1);
result = fastMultiply(result, result);
if ((exponent & 1) == 1) {
result = fastMultiply(result, this);
}
return result;
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment