From d84664b5670d471537dcecd03622591f4e92739f Mon Sep 17 00:00:00 2001 From: Cynopolis Date: Thu, 5 Jun 2025 15:02:42 -0400 Subject: [PATCH] Improved on old unit tests --- src/Matrix.cpp | 36 +++++---- src/Matrix.hpp | 9 +-- src/Quaternion.h | 141 ++++++++++++++++---------------- unit-tests/matrix-tests.cpp | 157 +++++++++++++++++++++++------------- 4 files changed, 195 insertions(+), 148 deletions(-) diff --git a/src/Matrix.cpp b/src/Matrix.cpp index 68e8d63..a6c1c59 100644 --- a/src/Matrix.cpp +++ b/src/Matrix.cpp @@ -13,12 +13,6 @@ #include #include #include -#include - -template -Matrix::Matrix(float value) { - this->Fill(value); -} template Matrix::Matrix(const std::array &array) { @@ -32,6 +26,14 @@ Matrix::Matrix(Args... args) { static_cast(columns)}; std::initializer_list initList{static_cast(args)...}; + // if there is only one value, we actually want to do a fill + if (sizeof...(args) == 1) { + this->Fill(*initList.begin()); + } + static_assert(sizeof...(args) == arraySize || sizeof...(args) == 1, + "You did not provide the right amount of initializers for this " + "matrix size"); + // choose whichever buffer size is smaller for the copy length uint32_t minSize = std::min(arraySize, static_cast(initList.size())); @@ -39,11 +41,13 @@ Matrix::Matrix(Args... args) { } template -void Matrix::Identity() { - this->Fill(0); - for (uint8_t idx{0}; idx < rows; idx++) { - this->matrix[idx * columns + idx] = 1; +Matrix Matrix::Identity() { + Matrix identityMatrix{0}; + uint32_t minDimension = std::min(rows, columns); + for (uint8_t idx{0}; idx < minDimension; idx++) { + identityMatrix[idx][idx] = 1; } + return identityMatrix; } template @@ -564,16 +568,18 @@ void Matrix::EigenQR(Matrix &eigenVectors, uint32_t maxIterations, float tolerance) const { static_assert(rows > 1, "Matrix size must be > 1 for QR iteration"); + static_assert(rows == columns, "Matrix size must be square for QR iteration"); Matrix Ak = *this; // Copy original matrix - Matrix QQ{}; - QQ.Identity(); + Matrix QQ{Matrix::Identity()}; for (uint32_t iter = 0; iter < maxIterations; ++iter) { - Matrix Q, R; - Ak.QRDecomposition(Q, R); + Matrix Q, R, shift; - Ak = R * Q; + // QR shift lets us "attack" the first diagonal to speed up the algorithm + shift = Matrix::Identity() * Ak[rows - 1][rows - 1]; + (Ak - shift).QRDecomposition(Q, R); + Ak = R * Q + shift; QQ = QQ * Q; // Check convergence: off-diagonal norm diff --git a/src/Matrix.hpp b/src/Matrix.hpp index 7ce6940..d4b1798 100644 --- a/src/Matrix.hpp +++ b/src/Matrix.hpp @@ -18,11 +18,6 @@ public: */ Matrix() = default; - /** - * @brief Create a matrix but fill all of its entries with one value - */ - Matrix(float value); - /** * @brief Initialize a matrix with an array */ @@ -39,9 +34,9 @@ public: template Matrix(Args... args); /** - * @brief set the matrix diagonals to 1 and all other values to 0 + * @brief Create an identity matrix */ - void Identity(); + static Matrix Identity(); /** * @brief Set all elements in this to value diff --git a/src/Quaternion.h b/src/Quaternion.h index 8be1486..df0acea 100644 --- a/src/Quaternion.h +++ b/src/Quaternion.h @@ -2,90 +2,89 @@ #define QUATERNION_H_ #include "Matrix.hpp" -class Quaternion : public Matrix<1, 4> -{ +class Quaternion : public Matrix<1, 4> { public: - Quaternion() : Matrix<1, 4>() {} - Quaternion(float fillValue) : Matrix<1, 4>(fillValue) {} - Quaternion(float w, float v1, float v2, float v3) : Matrix<1, 4>(w, v1, v2, v3) {} - Quaternion(const Quaternion &q) : Matrix<1, 4>(q.w, q.v1, q.v2, q.v3) {} - Quaternion(const Matrix<1, 4> &matrix) : Matrix<1, 4>(matrix) {} - Quaternion(const std::array &array) : Matrix<1, 4>(array) {} + Quaternion() : Matrix<1, 4>() {} + Quaternion(float w, float v1, float v2, float v3) + : Matrix<1, 4>(w, v1, v2, v3) {} + Quaternion(const Quaternion &q) : Matrix<1, 4>(q.w, q.v1, q.v2, q.v3) {} + Quaternion(const Matrix<1, 4> &matrix) : Matrix<1, 4>(matrix) {} + Quaternion(const std::array &array) : Matrix<1, 4>(array) {} - /** - * @brief Create a quaternion from an angle and axis - * @param angle The angle to rotate by - * @param axis The axis to rotate around - */ - static Quaternion FromAngleAndAxis(float angle, const Matrix<1, 3> &axis); + /** + * @brief Create a quaternion from an angle and axis + * @param angle The angle to rotate by + * @param axis The axis to rotate around + */ + static Quaternion FromAngleAndAxis(float angle, const Matrix<1, 3> &axis); - /** - * @brief Access the elements of the quaternion - * @param index The index of the element to access - * @return The value of the element at the index - */ - float operator[](uint8_t index) const; + /** + * @brief Access the elements of the quaternion + * @param index The index of the element to access + * @return The value of the element at the index + */ + float operator[](uint8_t index) const; - /** - * @brief Assign one quaternion to another - */ - void operator=(const Quaternion &other); + /** + * @brief Assign one quaternion to another + */ + void operator=(const Quaternion &other); - /** - * @brief Do quaternion multiplication - */ - Quaternion operator*(const Quaternion &other) const; + /** + * @brief Do quaternion multiplication + */ + Quaternion operator*(const Quaternion &other) const; - /** - * @brief Multiply the quaternion by a scalar - */ - Quaternion operator*(float scalar) const; + /** + * @brief Multiply the quaternion by a scalar + */ + Quaternion operator*(float scalar) const; - /** - * @brief Add two quaternions together - * @param other The quaternion to add to this one - * @return The net quaternion - */ - Quaternion operator+(const Quaternion &other) const; + /** + * @brief Add two quaternions together + * @param other The quaternion to add to this one + * @return The net quaternion + */ + Quaternion operator+(const Quaternion &other) const; - /** - * @brief Q_Mult a quaternion by another quaternion - * @param other The quaternion to rotate by - * @param buffer The buffer to store the result in - * @return A reference to the buffer - */ - Quaternion &Q_Mult(const Quaternion &other, Quaternion &buffer) const; + /** + * @brief Q_Mult a quaternion by another quaternion + * @param other The quaternion to rotate by + * @param buffer The buffer to store the result in + * @return A reference to the buffer + */ + Quaternion &Q_Mult(const Quaternion &other, Quaternion &buffer) const; - /** - * @brief Rotate a quaternion by this quaternion - * @param other The quaternion to rotate - * @param buffer The buffer to store the result in - * - */ - Quaternion &Rotate(Quaternion &other, Quaternion &buffer) const; + /** + * @brief Rotate a quaternion by this quaternion + * @param other The quaternion to rotate + * @param buffer The buffer to store the result in + * + */ + Quaternion &Rotate(Quaternion &other, Quaternion &buffer) const; - /** - * @brief Normalize the quaternion to a magnitude of 1 - */ - void Normalize(); + /** + * @brief Normalize the quaternion to a magnitude of 1 + */ + void Normalize(); - /** - * @brief Convert the quaternion to a rotation matrix - * @return The rotation matrix - */ - Matrix<3, 3> ToRotationMatrix() const; + /** + * @brief Convert the quaternion to a rotation matrix + * @return The rotation matrix + */ + Matrix<3, 3> ToRotationMatrix() const; - /** - * @brief Convert the quaternion to an Euler angle representation - * @return The Euler angle representation of the quaternion - */ - Matrix<3, 1> ToEulerAngle() const; + /** + * @brief Convert the quaternion to an Euler angle representation + * @return The Euler angle representation of the quaternion + */ + Matrix<3, 1> ToEulerAngle() const; - // Give people an easy way to access the elements - float &w{matrix[0]}; - float &v1{matrix[1]}; - float &v2{matrix[2]}; - float &v3{matrix[3]}; + // Give people an easy way to access the elements + float &w{matrix[0]}; + float &v1{matrix[1]}; + float &v2{matrix[2]}; + float &v3{matrix[3]}; }; #endif // QUATERNION_H_ \ No newline at end of file diff --git a/unit-tests/matrix-tests.cpp b/unit-tests/matrix-tests.cpp index 2a6d86c..2cf8776 100644 --- a/unit-tests/matrix-tests.cpp +++ b/unit-tests/matrix-tests.cpp @@ -10,41 +10,61 @@ #include #include +// Helper functions +template +float matrixSum(const Matrix &matrix) { + float sum = 0; + for (uint32_t i = 0; i < rows * columns; i++) { + float number = matrix.ToArray()[i]; + sum += number * number; + } + return std::sqrt(sum); +} + +template +void printLabeledMatrix(const std::string &label, + const Matrix &matrix) { + std::string strBuf = ""; + matrix.ToString(strBuf); + std::cout << label << ":\n" << strBuf << std::endl; +} + +TEST_CASE("Initialization", "Matrix") { + SECTION("Array Initialization") { + std::array arr2{5, 6, 7, 8}; + Matrix<2, 2> mat2{arr2}; + // array initialization + REQUIRE(mat2.Get(0, 0) == 5); + REQUIRE(mat2.Get(0, 1) == 6); + REQUIRE(mat2.Get(1, 0) == 7); + REQUIRE(mat2.Get(1, 1) == 8); + } + + SECTION("Argument Pack Initialization") { + Matrix<2, 2> mat1{1, 2, 3, 4}; + // template pack initialization + REQUIRE(mat1.Get(0, 0) == 1); + REQUIRE(mat1.Get(0, 1) == 2); + REQUIRE(mat1.Get(1, 0) == 3); + REQUIRE(mat1.Get(1, 1) == 4); + } + + SECTION("Single Argument Pack Initialization") { + Matrix<2, 2> mat1{2}; + // template pack initialization + REQUIRE(mat1.Get(0, 0) == 2); + REQUIRE(mat1.Get(0, 1) == 2); + REQUIRE(mat1.Get(1, 0) == 2); + REQUIRE(mat1.Get(1, 1) == 2); + } +} + TEST_CASE("Elementary Matrix Operations", "Matrix") { std::array arr2{5, 6, 7, 8}; Matrix<2, 2> mat1{1, 2, 3, 4}; Matrix<2, 2> mat2{arr2}; Matrix<2, 2> mat3{}; - SECTION("Initialization") { - // array initialization - REQUIRE(mat1.Get(0, 0) == 1); - REQUIRE(mat1.Get(0, 1) == 2); - REQUIRE(mat1.Get(1, 0) == 3); - REQUIRE(mat1.Get(1, 1) == 4); - - // empty initialization - REQUIRE(mat3.Get(0, 0) == 0); - REQUIRE(mat3.Get(0, 1) == 0); - REQUIRE(mat3.Get(1, 0) == 0); - REQUIRE(mat3.Get(1, 1) == 0); - - // template pack initialization - REQUIRE(mat2.Get(0, 0) == 5); - REQUIRE(mat2.Get(0, 1) == 6); - REQUIRE(mat2.Get(1, 0) == 7); - REQUIRE(mat2.Get(1, 1) == 8); - - // large matrix - Matrix<255, 255> mat6{}; - mat6.Fill(4); - for (uint8_t row{0}; row < 255; row++) { - for (uint8_t column{0}; column < 255; column++) { - REQUIRE(mat6.Get(row, column) == 4); - } - } - } - SECTION("Fill") { mat1.Fill(0); REQUIRE(mat1.Get(0, 0) == 0); @@ -66,10 +86,6 @@ TEST_CASE("Elementary Matrix Operations", "Matrix") { } SECTION("Addition") { - std::string strBuf1 = ""; - mat1.ToString(strBuf1); - std::cout << "Matrix 1:\n" << strBuf1 << std::endl; - mat1.Add(mat2, mat3); REQUIRE(mat3.Get(0, 0) == 6); @@ -363,18 +379,58 @@ TEST_CASE("Elementary Matrix Operations", "Matrix") { } } -template -float matrixSum(const Matrix &matrix) { - float sum = 0; - for (uint32_t i = 0; i < rows * columns; i++) { - float number = matrix.ToArray()[i]; - sum += number * number; +TEST_CASE("Identity Matrix", "Matrix") { + SECTION("Square Matrix") { + Matrix<5, 5> matrix = Matrix<5, 5>::Identity(); + uint32_t oneColumnIndex{0}; + for (uint32_t row = 0; row < 5; row++) { + for (uint32_t column = 0; column < 5; column++) { + float value = matrix[row][column]; + if (oneColumnIndex == column) { + REQUIRE_THAT(value, Catch::Matchers::WithinRel(1.0f, 1e-6f)); + } else { + REQUIRE_THAT(value, Catch::Matchers::WithinRel(0.0f, 1e-6f)); + } + } + oneColumnIndex++; + } + } + + SECTION("Wide Matrix") { + Matrix<2, 5> matrix = Matrix<2, 5>::Identity(); + + uint32_t oneColumnIndex{0}; + for (uint32_t row = 0; row < 2; row++) { + for (uint32_t column = 0; column < 5; column++) { + float value = matrix[row][column]; + if (oneColumnIndex == column && row < 3) { + REQUIRE_THAT(value, Catch::Matchers::WithinRel(1.0f, 1e-6f)); + } else { + REQUIRE_THAT(value, Catch::Matchers::WithinRel(0.0f, 1e-6f)); + } + } + oneColumnIndex++; + } + } + + SECTION("Tall Matrix") { + Matrix<5, 2> matrix = Matrix<5, 2>::Identity(); + uint32_t oneColumnIndex{0}; + for (uint32_t row = 0; row < 5; row++) { + for (uint32_t column = 0; column < 2; column++) { + float value = matrix[row][column]; + if (oneColumnIndex == column) { + REQUIRE_THAT(value, Catch::Matchers::WithinRel(1.0f, 1e-6f)); + } else { + REQUIRE_THAT(value, Catch::Matchers::WithinRel(0.0f, 1e-6f)); + } + } + oneColumnIndex++; + } } - return std::sqrt(sum); } // TODO: Add test for scalar division - TEST_CASE("Euclidean Norm", "Matrix") { SECTION("2x2 Normalize") { @@ -469,18 +525,10 @@ TEST_CASE("QR Decompositions", "Matrix") { SECTION("3x3 QRDecomposition") { // this symmetrix tridiagonal matrix is well behaved for testing Matrix<3, 3> A{1, 2, 3, 4, 5, 6, 7, 8, 9}; - uint32_t matrixRank = 2; Matrix<3, 3> Q{}, R{}; A.QRDecomposition(Q, R); - std::string strBuf1 = ""; - Q.ToString(strBuf1); - std::cout << "Q:\n" << strBuf1 << std::endl; - strBuf1 = ""; - R.ToString(strBuf1); - std::cout << "R:\n" << strBuf1 << std::endl; - // Check that Q * R ≈ A Matrix<3, 3> QR{}; QR = Q * R; @@ -491,13 +539,13 @@ TEST_CASE("QR Decompositions", "Matrix") { } // Check that Qᵀ * Q ≈ I - // In this case the A matrix is only rank 2, so the identity matrix given by - // Qᵀ * Q is actually only going to be 2x2. + // This MUST be true even if the rank of A is 2 because without this, + // calculating eigenvalues/vectors will not work. Matrix<3, 3> Qt = Q.Transpose(); Matrix<3, 3> QtQ{}; QtQ = Qt * Q; - for (int i = 0; i < matrixRank; ++i) { - for (int j = 0; j < matrixRank; ++j) { + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { if (i == j) REQUIRE_THAT(QtQ[i][j], Catch::Matchers::WithinRel(1.0f, 1e-4f)); else @@ -506,8 +554,7 @@ TEST_CASE("QR Decompositions", "Matrix") { } // Optional: Check R is upper triangular - // The matrix's rank is only 2 so the last row will not be triangular - for (int i = 1; i < matrixRank; ++i) { + for (int i = 1; i < 3; ++i) { for (int j = 0; j < i; ++j) { REQUIRE(std::fabs(R[i][j]) < 1e-4f); }