diff --git a/Matrix.hpp b/Matrix.hpp index 3a4bea2..8c1f76c 100644 --- a/Matrix.hpp +++ b/Matrix.hpp @@ -8,13 +8,24 @@ template class Matrix { public: - Matrix(); + /** + * @brief create a matrix but leave all of its values unitialized + */ + Matrix() = default; + + /** + * @brief Create a matrix but fill all of its entries with one value + */ + Matrix(float value); /** * @brief Initialize a matrix with an array */ Matrix(const std::array &array); + /** + * @brief Initialize a matrix as a copy of another matrix + */ Matrix(const Matrix &other); // TODO: Figure out how to do this /** @@ -22,6 +33,10 @@ public: */ // template // Matrix(Args&&... args); + /** + * @brief Set all elements in this to value + */ + void Fill(float value); /** * @brief Element-wise matrix addition @@ -82,6 +97,11 @@ public: */ Matrix &ElementDivide(const Matrix &other, Matrix &result) const; + + Matrix & + MinorMatrix(Matrix &result, uint8_t row_idx, + uint8_t column_idx) const; + /** * @return Get the determinant of the matrix * @note for right now only 2x2 and 3x3 matrices are supported @@ -146,23 +166,23 @@ public: * @brief get the specified row of the matrix returned as a reference to the * internal array */ - std::array &operator[](uint8_t row_index) { - if (row_index > rows - 1) { - return this->matrix[0]; // TODO: We should throw something here instead of - // failing quietly. - } - return this->matrix[row_index]; - } + std::array &operator[](uint8_t row_index); - Matrix &operator=(const Matrix &other) { - for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { - for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { - this->matrix[row_idx][column_idx] = other.Get(row_idx, column_idx); - } - } - // return a reference to ourselves so you can chain together these functions - return *this; - } + /** + * @brief Copy the contents of other into this matrix + */ + Matrix &operator=(const Matrix &other); + + /** + * @brief Return a new matrix that is the sum of this matrix and other matrix + */ + Matrix operator+(const Matrix &other) const; + + Matrix operator-(const Matrix &other) const; + + Matrix operator*(const Matrix &other) const; + + Matrix operator*(float scalar) const; private: /** @@ -175,17 +195,9 @@ private: template static float dotProduct(const Matrix &vec1, const Matrix &vec2); - /** - * @brief Set all elements in this matrix to zero - */ - void zeroMatrix(); Matrix &matrixOfMinors(Matrix &result) const; - Matrix & - minorMatrix(Matrix &result, uint8_t row_idx, - uint8_t column_idx) const; - Matrix &adjugate(Matrix &result) const; void setMatrixToArray(const std::array &array); @@ -210,8 +222,9 @@ void Matrix::setMatrixToArray( } } -template Matrix::Matrix() { - this->zeroMatrix(); +template +Matrix::Matrix(float value) { + this->Fill(value); } template @@ -321,9 +334,9 @@ Matrix::Invert(Matrix &result) const { // unfortunately we can't calculate this at compile time so we'll just reurn // zeros float determinant{this->Det()}; - if (this->Det() < 0) { + if (determinant < 0) { // you can't invert a matrix with a negative determinant - result.zeroMatrix(); + result.Fill(0); return result; } @@ -339,7 +352,7 @@ Matrix::Invert(Matrix &result) const { minors.adjugate(result); // scale the result by 1/determinant and we have our answer - result.Mult(1 / determinant); + result.Mult(1 / determinant, result); return result; } @@ -368,51 +381,27 @@ Matrix::Square(Matrix &result) const { return result; } -// explicitly define the determinant for a 3x3 matrix because it is definitely -// the fastest way to calculte a 2x2 matrix determinant +// explicitly define the determinant for a 2x2 matrix because it is definitely +// the fastest way to calculate a 2x2 matrix determinant template <> float Matrix<2, 2>::Det() const { return this->matrix[0][0] * this->matrix[1][1] - - this->matrix[0][1] * this->matrix[1][1]; -} - -// explicitly define the determinant for a 3x3 matrix because it will probably -// be faster than the jacobi method for nxn matrices -template <> float Matrix<3, 3>::Det() const { - float a{this->matrix[0][0]}; - float b{this->matrix[0][1]}; - float c{this->matrix[0][2]}; - - Matrix<2, 2> minors{}; - this->minorMatrix(minors, 0, 0); - float det = a * minors.Det(); - - this->minorMatrix(minors, 0, 1); - det -= b * minors.Det(); - - this->minorMatrix(minors, 0, 2); - det += c * minors.Det(); - - return det; + this->matrix[0][1] * this->matrix[1][0]; } template float Matrix::Det() const { static_assert(rows == columns, "You can't take the determinant of a non-square matrix."); - // static_assert( - // false, - // "Right now this operation isn't supported for matrices bigger than - // 3x3"); - // Matrix<1, columns> eigenValues{}; - // this->EigenValues(eigenValues); + Matrix MinorMatrix{}; + float determinant{0}; + for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { + // for odd indices the sign is negative + float sign = (column_idx % 2 == 0) ? 1 : -1; + determinant += sign * this->matrix[0][column_idx] * + this->MinorMatrix(MinorMatrix, 0, column_idx).Det(); + } - // float determinant{1}; - // for (uint8_t i{0}; i < columns; i++) { - // determinant *= eigenValues.Get(0, i); - // } - - // return determinant; - return 0; + return determinant; } template @@ -486,6 +475,59 @@ void Matrix::ToString(std::string &stringBuffer) const { } } +template +std::array &Matrix:: +operator[](uint8_t row_index) { + if (row_index > rows - 1) { + return this->matrix[0]; // TODO: We should throw something here instead of + // failing quietly. + } + return this->matrix[row_index]; +} + +template +Matrix &Matrix:: +operator=(const Matrix &other) { + for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { + for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { + this->matrix[row_idx][column_idx] = other.Get(row_idx, column_idx); + } + } + // return a reference to ourselves so you can chain together these functions + return *this; +} + +template +Matrix Matrix:: +operator+(const Matrix &other) const { + Matrix buffer{}; + this->Add(other, buffer); + return buffer; +} + +template +Matrix Matrix:: +operator-(const Matrix &other) const { + Matrix buffer{}; + this->Sub(other, buffer); + return buffer; +} + +template +Matrix Matrix:: +operator*(const Matrix &other) const { + Matrix buffer{}; + this->Mult(other, buffer); + return buffer; +} + +template +Matrix Matrix::operator*(float scalar) const { + Matrix buffer{}; + this->Mult(scalar, buffer); + return buffer; +} + template template float Matrix::dotProduct(const Matrix<1, vector_size> &vec1, @@ -511,10 +553,10 @@ float Matrix::dotProduct(const Matrix &vec1, } template -void Matrix::zeroMatrix() { +void Matrix::Fill(float value) { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { - this->matrix[row_idx][column_idx] = 0; + this->matrix[row_idx][column_idx] = value; } } } @@ -522,12 +564,12 @@ void Matrix::zeroMatrix() { template Matrix & Matrix::matrixOfMinors(Matrix &result) const { - Matrix minorMatrix{}; + Matrix MinorMatrix{}; for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { - this->minorMatrix(minorMatrix, row_idx, column_idx); - result[row_idx][column_idx] = minorMatrix.Det(); + this->MinorMatrix(MinorMatrix, row_idx, column_idx); + result[row_idx][column_idx] = MinorMatrix.Det(); } } @@ -536,18 +578,20 @@ Matrix::matrixOfMinors(Matrix &result) const { template Matrix & -Matrix::minorMatrix(Matrix &result, +Matrix::MinorMatrix(Matrix &result, uint8_t row_idx, uint8_t column_idx) const { std::array subArray{}; - + uint16_t array_idx{0}; for (uint8_t row_iter{0}; row_iter < rows; row_iter++) { + if (row_iter == row_idx) { + continue; + } for (uint8_t column_iter{0}; column_iter < columns; column_iter++) { - uint16_t array_idx = - static_cast(row_iter) + static_cast(column_iter); - if (row_iter == row_idx || column_iter == column_idx) { + if (column_iter == column_idx) { continue; } subArray[array_idx] = this->Get(row_iter, column_iter); + array_idx++; } } @@ -582,7 +626,7 @@ Matrix::Normalize(Matrix &result) const { if (sum == 0) { // this wouldn't do anything anyways - result.zeroMatrix(); + result.Fill(0); return result; } diff --git a/unit-tests/matrix-tests.cpp b/unit-tests/matrix-tests.cpp index 474afee..10c7c9e 100644 --- a/unit-tests/matrix-tests.cpp +++ b/unit-tests/matrix-tests.cpp @@ -29,6 +29,27 @@ TEST_CASE("Elementary Matrix Operations", "Matrix") { REQUIRE(mat3.Get(0, 1) == 0); REQUIRE(mat3.Get(1, 0) == 0); REQUIRE(mat3.Get(1, 1) == 0); + // TODO: what about a matrix of size 255x255? + } + + SECTION("Fill") { + mat1.Fill(0); + REQUIRE(mat1.Get(0, 0) == 0); + REQUIRE(mat1.Get(0, 1) == 0); + REQUIRE(mat1.Get(1, 0) == 0); + REQUIRE(mat1.Get(1, 1) == 0); + + mat2.Fill(100000); + REQUIRE(mat2.Get(0, 0) == 100000); + REQUIRE(mat2.Get(0, 1) == 100000); + REQUIRE(mat2.Get(1, 0) == 100000); + REQUIRE(mat2.Get(1, 1) == 100000); + + mat3.Fill(-20); + REQUIRE(mat3.Get(0, 0) == -20); + REQUIRE(mat3.Get(0, 1) == -20); + REQUIRE(mat3.Get(1, 0) == -20); + REQUIRE(mat3.Get(1, 1) == -20); } SECTION("Addition") { @@ -42,6 +63,14 @@ TEST_CASE("Elementary Matrix Operations", "Matrix") { REQUIRE(mat3.Get(0, 1) == 8); REQUIRE(mat3.Get(1, 0) == 10); REQUIRE(mat3.Get(1, 1) == 12); + + // try out addition with overloaded operators + mat3.Fill(0); + mat3 = mat1 + mat2; + REQUIRE(mat3.Get(0, 0) == 6); + REQUIRE(mat3.Get(0, 1) == 8); + REQUIRE(mat3.Get(1, 0) == 10); + REQUIRE(mat3.Get(1, 1) == 12); } SECTION("Subtraction") { @@ -51,6 +80,14 @@ TEST_CASE("Elementary Matrix Operations", "Matrix") { REQUIRE(mat3.Get(0, 1) == -4); REQUIRE(mat3.Get(1, 0) == -4); REQUIRE(mat3.Get(1, 1) == -4); + + // try out subtraction with operators + mat3.Fill(0); + mat3 = mat1 - mat2; + REQUIRE(mat3.Get(0, 0) == -4); + REQUIRE(mat3.Get(0, 1) == -4); + REQUIRE(mat3.Get(1, 0) == -4); + REQUIRE(mat3.Get(1, 1) == -4); } SECTION("Multiplication") { @@ -61,7 +98,13 @@ TEST_CASE("Elementary Matrix Operations", "Matrix") { REQUIRE(mat3.Get(1, 0) == 43); REQUIRE(mat3.Get(1, 1) == 50); - // try a non-square matrix + // try out multiplication with operators + mat3.Fill(0); + mat3 = mat1 * mat2; + REQUIRE(mat3.Get(0, 0) == 19); + REQUIRE(mat3.Get(0, 1) == 22); + REQUIRE(mat3.Get(1, 0) == 43); + REQUIRE(mat3.Get(1, 1) == 50); } SECTION("Scalar Multiplication") { @@ -94,10 +137,47 @@ TEST_CASE("Elementary Matrix Operations", "Matrix") { SECTION("Element Divide") { mat1.ElementDivide(mat2, mat3); - REQUIRE(mat3.Get(0, 0) == 1 / 5); - REQUIRE(mat3.Get(0, 1) == 2 / 6); - REQUIRE(mat3.Get(1, 0) == 3 / 7); - REQUIRE(mat3.Get(1, 1) == 4 / 8); + REQUIRE_THAT(mat3.Get(0, 0), Catch::Matchers::WithinRel(0.2f, 1e-6f)); + REQUIRE_THAT(mat3.Get(0, 1), Catch::Matchers::WithinRel(0.3333333f, 1e-6f)); + REQUIRE_THAT(mat3.Get(1, 0), Catch::Matchers::WithinRel(0.4285714f, 1e-6f)); + REQUIRE_THAT(mat3.Get(1, 1), Catch::Matchers::WithinRel(0.5f, 1e-6f)); + } + + SECTION("Minor Matrix") { + // what about matrices of 0,0 or 1,1? + // minor matrix for 2x2 matrix + Matrix<1, 1> minorMat1{}; + mat1.MinorMatrix(minorMat1, 0, 0); + REQUIRE(minorMat1.Get(0, 0) == 4); + mat1.MinorMatrix(minorMat1, 0, 1); + REQUIRE(minorMat1.Get(0, 0) == 3); + mat1.MinorMatrix(minorMat1, 1, 0); + REQUIRE(minorMat1.Get(0, 0) == 2); + mat1.MinorMatrix(minorMat1, 1, 1); + REQUIRE(minorMat1.Get(0, 0) == 1); + + // minor matrix for 3x3 matrix + std::array arr4{1, 2, 3, 4, 5, 6, 7, 8, 9}; + Matrix<3, 3> mat4{arr4}; + Matrix<2, 2> minorMat4{}; + + mat4.MinorMatrix(minorMat4, 0, 0); + REQUIRE(minorMat4.Get(0, 0) == 5); + REQUIRE(minorMat4.Get(0, 1) == 6); + REQUIRE(minorMat4.Get(1, 0) == 8); + REQUIRE(minorMat4.Get(1, 1) == 9); + + mat4.MinorMatrix(minorMat4, 1, 1); + REQUIRE(minorMat4.Get(0, 0) == 1); + REQUIRE(minorMat4.Get(0, 1) == 3); + REQUIRE(minorMat4.Get(1, 0) == 7); + REQUIRE(minorMat4.Get(1, 1) == 9); + + mat4.MinorMatrix(minorMat4, 2, 2); + REQUIRE(minorMat4.Get(0, 0) == 1); + REQUIRE(minorMat4.Get(0, 1) == 2); + REQUIRE(minorMat4.Get(1, 0) == 4); + REQUIRE(minorMat4.Get(1, 1) == 5); } SECTION("Determinant") { @@ -119,7 +199,13 @@ TEST_CASE("Elementary Matrix Operations", "Matrix") { REQUIRE_THAT(det5, Catch::Matchers::WithinRel(6.0F, 1e-6f)); } - SECTION("Invert"){}; + SECTION("Invert"){ + // mat1.Invert(mat3); + // REQUIRE_THAT(mat3.Get(0, 0), Catch::Matchers::WithinRel(-2.0F, 1e-6f)); + // REQUIRE_THAT(mat3.Get(0, 0), Catch::Matchers::WithinRel(1.0F, 1e-6f)); + // REQUIRE_THAT(mat3.Get(0, 0), Catch::Matchers::WithinRel(1.5F, 1e-6f)); + // REQUIRE_THAT(mat3.Get(0, 0), Catch::Matchers::WithinRel(-0.5F, 1e-6f)); + }; SECTION("Transpose") { // transpose a square matrix