#pragma once #ifndef MATRIX_H_ #define MATRIX_H_ #include #include #include #include #include #include "vector.h" //============================================================================= // MatrixSize //============================================================================= struct MatrixSize : public std::array { MatrixSize() {} MatrixSize(size_t num_rows, size_t num_cols) { this->operator[](0) = num_rows; this->operator[](1) = num_cols; } size_t GetNumRows() const { return this->at(0); } size_t GetNumColums() const { return this->at(1); } }; //============================================================================= // Matrix // Simple data type that allows to represent a 2D matrix in column-major // order. This class also offers several operators to manipulate the data. //============================================================================= template class Matrix { public: //=========================================================================== // Matrix c-tors and d-tors //=========================================================================== Matrix(size_t num_rows, size_t num_cols, T value = static_cast(0)); Matrix(size_t num_rows, size_t num_cols, const std::vector& data); Matrix(); ~Matrix(); //=========================================================================== // Basic accessors //=========================================================================== size_t GetNumRows() const; size_t GetNumColumns() const; MatrixSize Size() const; size_t Numel() const; //=========================================================================== // Indexing //=========================================================================== std::array Index2Subscripts(size_t idx) const; size_t Subscripts2Index(size_t i, size_t j) const; T& operator()(size_t i, size_t j); const T& operator()(size_t i, size_t j) const; T& operator()(size_t idx); const T& operator()(size_t idx) const; //=========================================================================== // Accessors //=========================================================================== void SetRow(const std::vector& row, size_t i); void SetColumn(const std::vector& column, size_t j); void SetData(const std::vector& data); std::vector GetRow(size_t i) const; std::vector GetColumn(size_t j) const; std::vector& GetData(); const std::vector& GetData() const; void DeleteRow(size_t i); void DeleteColumn(size_t j); //=========================================================================== // Utils //=========================================================================== std::vector Diag() const; void Fill(T value); void Rand(unsigned int seed); void Rand(); void Identity(); void Transpose(); void Reshape(size_t num_rows, size_t num_cols); void Round(); void Floor(); void Ceil(); void Clamp(T min_value, T max_value); T Max() const; T Min() const; //=========================================================================== // Binary Matrix-Matrix operations //=========================================================================== Matrix operator*(const Matrix& rhs) const; Matrix operator+(Matrix rhs) const; Matrix operator-(Matrix rhs) const; Matrix ElementProduct(Matrix rhs) const; Matrix ElementQuotient(Matrix rhs) const; //=========================================================================== // Binary Scalar-Matrix operations //=========================================================================== Matrix operator+(T rhs) const; Matrix operator-(T rhs) const; Matrix operator*(T rhs) const; Matrix operator/(T rhs) const; Matrix operator^(T rhs) const; //=========================================================================== // Unary operations //=========================================================================== Matrix operator-() const; private: //=========================================================================== // Member variables //=========================================================================== size_t num_rows_; size_t num_cols_; std::vector data_; }; // class Matrix //============================================================================= // Overloading operator<< for ostream operations //============================================================================= // Overloaded operator<< for Matrix template inline std::ostream& operator<<(std::ostream& out, const Matrix& matrix) { for (size_t i = 0, endI = matrix.GetNumRows(); i < endI; i++) { out << matrix.GetRow(i) << std::endl; } return out; } #endif // MATRIX_H_