Skip to content

Instantly share code, notes, and snippets.

@Ben1980
Last active November 14, 2019 19:10
Show Gist options
  • Save Ben1980/25892b090ebe0f4453bb104df6ead9b7 to your computer and use it in GitHub Desktop.
Save Ben1980/25892b090ebe0f4453bb104df6ead9b7 to your computer and use it in GitHub Desktop.
/// @brief Dense, row-major matrix of arithmetic elements with uniquely owned storage.
///
/// Matrices created through the (rows, columns[, data]) constructors are validated
/// by AssertData: both dimensions must be non-zero and equal (square), otherwise
/// std::domain_error is thrown. A moved-from Matrix reports 0 x 0 and owns no
/// storage; only assignment, destruction, rows() and columns() are meaningful on it.
template<typename T>
class Matrix {
    static_assert(std::is_arithmetic<T>::value, "T must be numeric");
public:
    ~Matrix() = default;

    /// @brief Builds a rows x columns matrix by copying rows*columns elements from @p m.
    /// @param rows    Number of rows.
    /// @param columns Number of columns.
    /// @param m       Row-major source data; must point to at least rows*columns elements.
    /// @throws std::domain_error if a dimension is zero or the matrix is not square.
    Matrix(size_t rows, size_t columns, const T *m)
        : nbRows(rows), nbColumns(columns), matrix(std::make_unique<T[]>(rows*columns))
    {
        // Validate before reading the caller's buffer so an invalid shape never
        // touches the source data.
        AssertData(*this);
        std::copy(m, m + nbRows*nbColumns, matrix.get());
    }

    /// @brief Builds a zero-filled rows x columns matrix.
    /// @throws std::domain_error if a dimension is zero or the matrix is not square.
    Matrix(size_t rows, size_t columns)
        : nbRows(rows), nbColumns(columns), matrix(std::make_unique<T[]>(rows*columns))
    {
        // make_unique<T[]>(n) value-initialises every element, i.e. zero for
        // arithmetic T, so no explicit fill is required.
        AssertData(*this);
    }

    /// @brief Deep copy; allocates a new buffer and copies every element.
    Matrix(const Matrix<T> &m)
        : nbRows(m.nbRows), nbColumns(m.nbColumns),
          matrix(std::make_unique<T[]>(m.nbRows * m.nbColumns))
    {
        // size_t (not int) so large element counts cannot be truncated.
        const size_t size = nbRows * nbColumns;
        std::copy(m.matrix.get(), m.matrix.get() + size, matrix.get());
    }

    /// @brief Steals @p m's buffer; @p m is left as an empty 0 x 0 matrix.
    Matrix(Matrix<T> &&m) noexcept
        : nbRows(m.nbRows), nbColumns(m.nbColumns), matrix(std::move(m.matrix))
    {
        // The moved-from unique_ptr is already null; reset the shape to match.
        m.nbRows = 0;
        m.nbColumns = 0;
    }

    /// @brief Copy assignment (copy-and-swap): strong exception guarantee and
    ///        safe under self-assignment.
    Matrix<T> & operator=(const Matrix<T> &m) {
        Matrix tmp(m);
        // After the swap tmp owns, and frees, the previous buffer of *this.
        swapContents(tmp);
        return *this;
    }

    /// @brief Move assignment; @p m is left as an empty 0 x 0 matrix.
    Matrix<T> & operator=(Matrix<T> &&m) noexcept {
        // Moving through a temporary keeps self-move-assignment well defined.
        Matrix tmp(std::move(m));
        swapContents(tmp);
        return *this;
    }

    /// @brief Unchecked read access to element (row, column).
    const T & operator()(size_t row, size_t column) const {
        return matrix[row*nbColumns + column];
    }

    /// @brief Unchecked write access to element (row, column).
    T & operator()(size_t row, size_t column) {
        return matrix[row*nbColumns + column];
    }

    /// @return Number of rows.
    [[nodiscard]] size_t rows() const {
        return nbRows;
    }

    /// @return Number of columns.
    [[nodiscard]] size_t columns() const {
        return nbColumns;
    }

    template<typename U>
    friend Matrix<U> operator*(const Matrix<U> &lhs, const Matrix<U> & rhs);

private:
    /// @brief Throws std::domain_error unless @p m is non-empty and square.
    static void AssertData(const Matrix<T> &m) {
        if(m.nbRows == 0 || m.nbColumns == 0) {
            throw std::domain_error("Invalid defined matrix.");
        }
        if(m.nbRows != m.nbColumns) {
            throw std::domain_error("Matrix is not square.");
        }
    }

    /// @brief Exchanges shape and storage with @p other; never throws.
    void swapContents(Matrix<T> &other) noexcept {
        std::swap(nbRows, other.nbRows);
        std::swap(nbColumns, other.nbColumns);
        matrix.swap(other.matrix);
    }

    size_t nbRows{0};
    size_t nbColumns{0};
    std::unique_ptr<T[]> matrix;
};
/// @brief Matrix product lhs * rhs.
///
/// Both operands are validated with Matrix<U>::AssertData (non-empty and square)
/// and must be conformable, i.e. lhs.columns() == rhs.rows().
/// @param lhs Left operand.
/// @param rhs Right operand.
/// @return A new lhs.rows() x rhs.columns() matrix with
///         C(i, k) = sum over j of lhs(i, j) * rhs(j, k).
/// @throws std::domain_error if either operand fails AssertData or the inner
///         dimensions differ.
template<typename U>
Matrix<U> operator*(const Matrix<U> &lhs, const Matrix<U> & rhs) {
    Matrix<U>::AssertData(lhs);
    Matrix<U>::AssertData(rhs);
    // Conformability is decided by the inner dimensions. Comparing rows() only
    // worked because AssertData currently forces both operands to be square.
    if(lhs.columns() != rhs.rows()) {
        throw std::domain_error("Matrices have unequal size.");
    }

    const size_t lhsRows = lhs.rows();
    const size_t innerSize = lhs.columns();
    const size_t rhsColumns = rhs.columns();

    Matrix<U> C(lhsRows, rhsColumns);
    for (size_t i = 0; i < lhsRows; ++i) {
        for (size_t k = 0; k < rhsColumns; ++k) {
            // Accumulate locally (same summation order as before) so the output
            // element is written once per (i, k) instead of once per term.
            U sum{};
            for (size_t j = 0; j < innerSize; ++j) {
                sum += lhs(i, j) * rhs(j, k);
            }
            C(i, k) = sum;
        }
    }
    return C;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment