Created March 8, 2020 00:53
-
-
Save PhDP/3228222357731751fc40e287f5c8dffa to your computer and use it in GitHub Desktop.
Matrix multiplication for a matrix viewed as a function of type Integer -> Integer -> T (a row index and a column index map to an element of type T).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// C++17 | |
#include <array>
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <stdexcept>
// Row-major fixed-size matrix based on std::array: element (r, c) is stored
// at flat index r * COL + c.
//
// Template parameters:
//   ROW - number of rows
//   COL - number of columns
//   T   - element type; must be value-initializable (T{} acts as zero) and
//         support `*` and `+=` for multiplication.
template<std::size_t ROW, std::size_t COL, typename T>
class matrix {
public:
    // Every element is value-initialized (0 for arithmetic T).
    constexpr matrix() = default;

    // Builds the matrix from a list of rows, e.g. {{1, 2, 3}, {4, 5, 6}}.
    // Rows or trailing values that are not supplied stay value-initialized.
    // Throws std::out_of_range if there are more than ROW rows or a row has
    // more than COL values (without this check such values would overwrite
    // the next row or run past the storage).
    constexpr matrix(std::initializer_list<std::initializer_list<T>> const& init) {
        if (init.size() > ROW) {
            throw std::out_of_range("matrix: too many rows in initializer");
        }
        std::size_t row_id = 0;
        for (auto const& row : init) {
            if (row.size() > COL) {
                throw std::out_of_range("matrix: too many values in initializer row");
            }
            std::size_t col_id = 0;
            for (auto const& value : row) {
                m_arr[row_id * COL + col_id] = value;
                ++col_id;
            }
            ++row_id;
        }
    }

    // Total number of elements (ROW * COL).
    constexpr auto size() const -> std::size_t {
        return ROW * COL;
    }

    // Element access through std::array::at, which throws std::out_of_range
    // when the flat index row_id * COL + col_id is outside the storage.
    constexpr auto operator()(std::size_t row_id, std::size_t col_id) const -> T const& {
        return m_arr.at(row_id * COL + col_id);
    }
    constexpr auto operator()(std::size_t row_id, std::size_t col_id) -> T& {
        return m_arr.at(row_id * COL + col_id);
    }

    // Naive product of this ROW x COL matrix with a COL x C matrix; the inner
    // dimensions are matched at compile time through the parameter type.
    template<std::size_t C>
    constexpr auto operator*(matrix<COL, C, T> const& other) const -> matrix<ROW, C, T> {
        auto ans = matrix<ROW, C, T>{};
        for (std::size_t row = 0; row < ROW; ++row) {
            for (std::size_t col = 0; col < C; ++col) {
                // Accumulate in a local that starts at T{} (zero for
                // arithmetic T) so the result never depends on prior storage.
                T sum{};
                for (std::size_t k = 0; k < COL; ++k) {
                    sum += (*this)(row, k) * other(k, col);
                }
                ans(row, col) = sum;
            }
        }
        return ans;
    }

private:
    // Row-major storage; `{}` value-initializes every element so default
    // construction and partial initializer lists yield zeros, not garbage.
    std::array<T, (ROW * COL)> m_arr{};
};
template<size_t ROW, size_t COL, typename T> | |
auto operator<<(std::ostream& os, matrix<ROW, COL, T> const& m) -> std::ostream& { | |
if (ROW == 0 || COL == 0) { | |
return os; | |
} | |
for (auto r = 0; r < ROW; ++r) { | |
auto c = 0; | |
os << '[' << m(r, c++); | |
for (; c < COL; ++c) { | |
os << ", " << m(r, c); | |
} | |
os << "]\n"; | |
} | |
return os; | |
} | |
auto main() -> int { | |
auto const A = matrix<2, 3, double>({{1, 2, 3}, {4, 5, 6}}); | |
auto const B = matrix<3, 2, double>({{7, 8}, {9, 10}, {11, 12}}); | |
std::cout << A << "\n*\n\n"; | |
std::cout << B << "\n=\n\n"; | |
std::cout << (A * B) << '\n'; | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment