From 5bbdefa4cf07b4091bca9060e3209ad2cc175136 Mon Sep 17 00:00:00 2001 From: Quinn Henthorne Date: Fri, 13 Dec 2024 10:04:06 -0500 Subject: [PATCH] Fixed matrix inversion --- Matrix.hpp | 55 +++++++++++++++---------------------- unit-tests/matrix-tests.cpp | 43 +++++++++++++++++++---------- 2 files changed, 50 insertions(+), 48 deletions(-) diff --git a/Matrix.hpp b/Matrix.hpp index 8c1f76c..1d86828 100644 --- a/Matrix.hpp +++ b/Matrix.hpp @@ -74,12 +74,6 @@ public: Matrix &Mult(float scalar, Matrix &result) const; - /** - * @brief Square this matrix - * @param result A buffer to store the result into - */ - Matrix &Square(Matrix &result) const; - /** * @brief Element-wise multiply the two matrices * @param other the other matrix to multiply into this one @@ -108,6 +102,8 @@ public: */ float Det() const; + Matrix &MatrixOfMinors(Matrix &result) const; + /** * @brief Invert this matrix * @param result A buffer to store the result into @@ -196,8 +192,6 @@ private: static float dotProduct(const Matrix &vec1, const Matrix &vec2); - Matrix &matrixOfMinors(Matrix &result) const; - Matrix &adjugate(Matrix &result) const; void setMatrixToArray(const std::array &array); @@ -290,16 +284,18 @@ template Matrix & Matrix::Mult(const Matrix &other, Matrix &result) const { + // allocate some buffers for all of our dot products + Matrix<1, columns> this_row; + Matrix other_column; + Matrix<1, rows> other_column_t; + for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { + // get our row + this->GetRow(row_idx, this_row); for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { - // get our row - Matrix<1, columns> this_row; - this->GetRow(row_idx, this_row); - // get the other matrices column - Matrix other_column; + // get the other matrix'ss column other.GetColumn(column_idx, other_column); // transpose the other matrix's column - Matrix<1, rows> other_column_t; other_column.Transpose(other_column_t); // the result's index is equal to the dot product of these two vectors @@ -334,7 +330,7 @@ Matrix::Invert(Matrix &result) const { // unfortunately we can't calculate this at compile time so we'll just reurn // zeros float determinant{this->Det()}; - if (determinant < 0) { + if (determinant == 0) { // you can't invert a matrix with a negative determinant result.Fill(0); return result; @@ -346,13 +342,14 @@ Matrix::Invert(Matrix &result) const { // calculate the matrix of minors Matrix minors{}; - this->matrixOfMinors(minors); + this->MatrixOfMinors(minors); // now adjugate the matrix and save it in our output minors.adjugate(result); // scale the result by 1/determinant and we have our answer - result.Mult(1 / determinant, result); + result = result * (1 / determinant); + // result.Mult(1 / determinant, result); return result; } @@ -369,20 +366,10 @@ Matrix::Transpose(Matrix &result) const { return result; } -template -Matrix & -Matrix::Square(Matrix &result) const { - // TODO: Because template requirements are checked before static_assert, this - // never throws an error and fails at the Mult call instead. - static_assert(rows == columns, "You can't square a non-square matrix."); - - this->Mult(*this, result); - - return result; -} - // explicitly define the determinant for a 2x2 matrix because it is definitely // the fastest way to calculate a 2x2 matrix determinant +template <> float Matrix<0, 0>::Det() const { return 1e+6; } +template <> float Matrix<1, 1>::Det() const { return this->matrix[0][0]; } template <> float Matrix<2, 2>::Det() const { return this->matrix[0][0] * this->matrix[1][1] - this->matrix[0][1] * this->matrix[1][0]; @@ -392,6 +379,7 @@ template float Matrix::Det() const { static_assert(rows == columns, "You can't take the determinant of a non-square matrix."); + Matrix MinorMatrix{}; float determinant{0}; for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { @@ -436,7 +424,8 @@ template float Matrix::Get(uint8_t row_index, uint8_t column_index) const { if (row_index > rows - 1 || column_index > columns - 1) { - return 0; // TODO: We should throw something here instead of failing quietly + return 1e+10; // TODO: We should throw something here instead of failing + // quietly } return this->matrix[row_index][column_index]; } @@ -563,7 +552,7 @@ void Matrix::Fill(float value) { template Matrix & -Matrix::matrixOfMinors(Matrix &result) const { +Matrix::MatrixOfMinors(Matrix &result) const { Matrix MinorMatrix{}; for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { @@ -606,7 +595,7 @@ Matrix::adjugate(Matrix &result) const { for (uint8_t column_iter{0}; column_iter < columns; column_iter++) { float sign = ((row_iter + 1) % 2) == 0 ? -1 : 1; sign *= ((column_iter + 1) % 2) == 0 ? -1 : 1; - result[row_iter][column_iter] = this->Get(row_iter, column_iter) * sign; + result[column_iter][row_iter] = this->Get(row_iter, column_iter) * sign; } } @@ -626,7 +615,7 @@ Matrix::Normalize(Matrix &result) const { if (sum == 0) { // this wouldn't do anything anyways - result.Fill(0); + result.Fill(1e+6); return result; } diff --git a/unit-tests/matrix-tests.cpp b/unit-tests/matrix-tests.cpp index 10c7c9e..714e4da 100644 --- a/unit-tests/matrix-tests.cpp +++ b/unit-tests/matrix-tests.cpp @@ -116,15 +116,6 @@ TEST_CASE("Elementary Matrix Operations", "Matrix") { REQUIRE(mat3.Get(1, 1) == 8); } - SECTION("Squaring") { - mat1.Square(mat3); - - REQUIRE(mat3.Get(0, 0) == 7); - REQUIRE(mat3.Get(0, 1) == 10); - REQUIRE(mat3.Get(1, 0) == 15); - REQUIRE(mat3.Get(1, 1) == 22); - } - SECTION("Element Multiply") { mat1.ElementMultiply(mat2, mat3); @@ -199,12 +190,34 @@ TEST_CASE("Elementary Matrix Operations", "Matrix") { REQUIRE_THAT(det5, Catch::Matchers::WithinRel(6.0F, 1e-6f)); } - SECTION("Invert"){ - // mat1.Invert(mat3); - // REQUIRE_THAT(mat3.Get(0, 0), Catch::Matchers::WithinRel(-2.0F, 1e-6f)); - // REQUIRE_THAT(mat3.Get(0, 0), Catch::Matchers::WithinRel(1.0F, 1e-6f)); - // REQUIRE_THAT(mat3.Get(0, 0), Catch::Matchers::WithinRel(1.5F, 1e-6f)); - // REQUIRE_THAT(mat3.Get(0, 0), Catch::Matchers::WithinRel(-0.5F, 1e-6f)); + SECTION("Matrix of Minors") { + mat1.MatrixOfMinors(mat3); + REQUIRE_THAT(mat3.Get(0, 0), Catch::Matchers::WithinRel(4.0F, 1e-6f)); + REQUIRE_THAT(mat3.Get(0, 1), Catch::Matchers::WithinRel(3.0F, 1e-6f)); + REQUIRE_THAT(mat3.Get(1, 0), Catch::Matchers::WithinRel(2.0F, 1e-6f)); + REQUIRE_THAT(mat3.Get(1, 1), Catch::Matchers::WithinRel(1.0F, 1e-6f)); + + std::array arr4{1, 2, 3, 4, 5, 6, 7, 8, 9}; + Matrix<3, 3> mat4{arr4}; + Matrix<3, 3> mat5{0}; + mat4.MatrixOfMinors(mat5); + REQUIRE_THAT(mat5.Get(0, 0), Catch::Matchers::WithinRel(-3.0F, 1e-6f)); + REQUIRE_THAT(mat5.Get(0, 1), Catch::Matchers::WithinRel(-6.0F, 1e-6f)); + REQUIRE_THAT(mat5.Get(0, 2), Catch::Matchers::WithinRel(-3.0F, 1e-6f)); + REQUIRE_THAT(mat5.Get(1, 0), Catch::Matchers::WithinRel(-6.0F, 1e-6f)); + REQUIRE_THAT(mat5.Get(1, 1), Catch::Matchers::WithinRel(-12.0F, 1e-6f)); + REQUIRE_THAT(mat5.Get(1, 2), Catch::Matchers::WithinRel(-6.0F, 1e-6f)); + REQUIRE_THAT(mat5.Get(2, 0), Catch::Matchers::WithinRel(-3.0F, 1e-6f)); + REQUIRE_THAT(mat5.Get(2, 1), Catch::Matchers::WithinRel(-6.0F, 1e-6f)); + REQUIRE_THAT(mat5.Get(2, 2), Catch::Matchers::WithinRel(-3.0F, 1e-6f)); + } + + SECTION("Invert") { + mat1.Invert(mat3); + REQUIRE_THAT(mat3.Get(0, 0), Catch::Matchers::WithinRel(-2.0F, 1e-6f)); + REQUIRE_THAT(mat3.Get(0, 1), Catch::Matchers::WithinRel(1.0F, 1e-6f)); + REQUIRE_THAT(mat3.Get(1, 0), Catch::Matchers::WithinRel(1.5F, 1e-6f)); + REQUIRE_THAT(mat3.Get(1, 1), Catch::Matchers::WithinRel(-0.5F, 1e-6f)); }; SECTION("Transpose") {