Skip to content

Instantly share code, notes, and snippets.

@qiuwch
Created May 10, 2015 03:40
Show Gist options
  • Select an option

  • Save qiuwch/c455ccb341eb6277aee9 to your computer and use it in GitHub Desktop.

Select an option

Save qiuwch/c455ccb341eb6277aee9 to your computer and use it in GitHub Desktop.
A simple matrix implementation, useful for a course project
#include <assert.h> // assert
#include <limits.h> // INT_MAX
#include <math.h>
#include <stdio.h> // fputs, stdout
#include <stdlib.h>
#include <string.h>
#include <algorithm> // std::min
#include <complex> // std::complex, std::abs
#include <iomanip> // std::setprecision
#include <ostream> // std::ostream, std::endl
#include <sstream> // std::ostringstream
#include <stdexcept> // std::invalid_argument, std::length_error
#include <vector> // std::vector
// When defined, SetVal/GetVal assert that the requested indices lie inside the matrix.
#define CHECK_ACC
#define EPS 10e-3 // absolute tolerance used by Matrix::operator== (10e-3 == 0.01)

// Dense, row-major matrix of T with value semantics.
//
// Element (y, x) is stored at index y * width + x. Storage is a std::vector,
// so copy construction, copy assignment and destruction are the
// compiler-generated ones: they deep-copy / release the elements and no
// manual memory management is needed.
template<typename T>
class Matrix {
public:
    // Creates a height x width matrix with every element set to `val`.
    // Throws std::invalid_argument if a dimension is negative and
    // std::length_error if height * width does not fit in an int (Size()
    // and the index arithmetic below are int based).
    Matrix(int height, int width, T val = 0)
        : m_width(width),
          m_height(height),
          m_elements(elementCount(height, width), val) {
    }

    // Kept virtual so that existing subclasses are still destroyed correctly
    // through a base pointer.
    virtual ~Matrix() {
    }

    // Approximate equality: true when both matrices have the same shape and
    // every pair of elements differs by at most EPS in magnitude.
    // An element pair whose difference is NaN counts as different, so a
    // matrix containing NaN is not equal to itself.
    // On the first mismatch one diagnostic line is written to stdout.
    bool operator==(const Matrix& other) const {
        if (m_width != other.m_width) return false;
        if (m_height != other.m_height) return false;
        for (int y = 0; y < m_height; y++) {
            for (int x = 0; x < m_width; x++) {
                const int i = y * m_width + x;
                const T diff = m_elements[i] - other.m_elements[i];
                // Written as !(<=) rather than (>) so that a NaN magnitude,
                // for which every comparison is false, is reported as a
                // mismatch. std::abs also covers std::complex elements.
                if (!(std::abs(diff) <= EPS)) {
                    reportMismatch(y, x, m_elements[i], other.m_elements[i]);
                    return false;
                }
            }
        }
        return true;
    }

    // Stores `val` at row y, column x.
    // Indices are only checked (via assert) when CHECK_ACC is defined; if
    // performance hungry, consider switching the boundary check off.
    void SetVal(int y, int x, T val) {
#ifdef CHECK_ACC
        assert(y >= 0 && x >= 0 && y < m_height && x < m_width);
#endif
        m_elements[y * m_width + x] = val;
    }

    // Returns the element at row y, column x (same checking rules as SetVal).
    T GetVal(int y, int x) const {
#ifdef CHECK_ACC
        assert(y >= 0 && x >= 0 && y < m_height && x < m_width);
#endif
        return m_elements[y * m_width + x];
    }

    // Number of elements (width * height), not the memory footprint.
    int Size() const {
        return m_width * m_height;
    }

    int GetWidth() const {
        return m_width;
    }

    int GetHeight() const {
        return m_height;
    }

    // Overwrites every element with a pseudo-random integer in 0..9 taken
    // from rand(); seeding is left to the caller.
    void RandInit() {
        for (int i = 0; i < Size(); i++) {
            m_elements[i] = static_cast<T>(rand() % 10);
        }
    }

    // Matrix operation
    // Returns the product (*this) * other computed with the plain triple loop.
    // When the inner dimensions do not match (width != other.height) an
    // empty 0 x 0 matrix is returned instead.
    Matrix MultiplyCPU(const Matrix& other) const {
        if (m_width != other.m_height) return Matrix(0, 0);
        Matrix C(m_height, other.m_width);
        for (int y = 0; y < C.m_height; y++) {
            for (int x = 0; x < C.m_width; x++) {
                T val = 0;
                for (int i = 0; i < m_width; i++) {
                    val += GetVal(y, i) * other.GetVal(i, x);
                }
                C.SetVal(y, x, val);
            }
        }
        return C;
    }

    // Returns the sum of all elements, accumulated in storage (row-major) order.
    T Sum() const {
        T sum = 0;
        for (int i = 0; i < Size(); i++) {
            sum += m_elements[i];
        }
        return sum;
    }

    // Writes a size header followed by the elements, one row per line, each
    // element followed by a space. With `truncated` set only the first
    // 10 rows and 10 columns are printed.
    void Output(std::ostream &os, bool truncated=true) const {
        os << "The size of matrix is width:" << m_width << " height:" << m_height << std::endl;
        const int W = truncated ? std::min(m_width, 10) : m_width;
        const int H = truncated ? std::min(m_height, 10) : m_height;
        for (int y = 0; y < H; y++) {
            for (int x = 0; x < W; x++) {
                os << GetVal(y, x) << " ";
            }
            os << std::endl;
        }
    }

private:
    // Validates the requested shape and returns the number of elements to allocate.
    static std::size_t elementCount(int height, int width) {
        if (height < 0 || width < 0) {
            throw std::invalid_argument("Matrix dimensions must be non-negative");
        }
        // Division form avoids overflowing while testing for overflow.
        if (width != 0 && height > INT_MAX / width) {
            throw std::length_error("Matrix element count does not fit in an int");
        }
        return static_cast<std::size_t>(height) * static_cast<std::size_t>(width);
    }

    // Prints the first differing element pair found by operator== to stdout.
    // The text is built with a stream (fixed notation, 2 decimals) instead of
    // printf("%.2f") so that it is well defined for every streamable T,
    // not only for float/double.
    static void reportMismatch(int y, int x, const T& lhs, const T& rhs) {
        std::ostringstream msg;
        msg << std::fixed << std::setprecision(2)
            << "y:" << y << " x:" << x << " val1:" << lhs << " val2:" << rhs << "\n";
        fputs(msg.str().c_str(), stdout);
    }

    int m_width;
    int m_height;
    std::vector<T> m_elements; // row-major, m_width * m_height entries
};
// Streams a human-readable dump of `mat` to `os`: one header line with the
// matrix size, then at most the first 10 rows and 10 columns, one row per
// line with a space after every element. Returns `os` to allow chaining.
template<typename T>
std::ostream& operator<<(std::ostream &os, const Matrix<T> &mat) {
    const int maxShown = 10; // rows/columns past this limit are not printed
    const int rows = std::min(mat.GetHeight(), maxShown);
    const int cols = std::min(mat.GetWidth(), maxShown);

    os << "The size of matrix is width:" << mat.GetWidth() << " height:" << mat.GetHeight() << std::endl;
    for (int row = 0; row < rows; ++row) {
        for (int col = 0; col < cols; ++col) {
            os << mat.GetVal(row, col) << " ";
        }
        os << std::endl;
    }
    return os;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment