From 6fdab5be30272678af9819162467a208feb93d0d Mon Sep 17 00:00:00 2001 From: Cynopolis Date: Fri, 30 May 2025 15:26:19 -0400 Subject: [PATCH] Added unit tests for eigen --- src/Matrix.cpp | 22 ++++++----- src/Matrix.hpp | 13 +++--- unit-tests/matrix-tests.cpp | 79 ++++++++++++++++++++++++++++++++++++- 3 files changed, 97 insertions(+), 17 deletions(-) diff --git a/src/Matrix.cpp b/src/Matrix.cpp index 6cdf4b6..8e075de 100644 --- a/src/Matrix.cpp +++ b/src/Matrix.cpp @@ -489,15 +489,15 @@ void Matrix::SetSubMatrix( // QR decomposition: decomposes this matrix A into Q and R // Assumes square matrix template -void Matrix::QRDecomposition(Matrix &Q, - Matrix &R) const { - // Use Gram-Schmidt orthogonalization for simplicity +void Matrix::QRDecomposition(Matrix &Q, + Matrix &R) const { + // Gram-Schmidt orthogonalization Matrix a_col, u, e, proj; Matrix q_col; Q.Fill(0); R.Fill(0); - for (uint8_t k = 0; k < rows; ++k) { + for (uint8_t k = 0; k < columns; ++k) { this->GetColumn(k, a_col); u = a_col; @@ -527,14 +527,14 @@ void Matrix::QRDecomposition(Matrix &Q, template void Matrix::EigenQR(Matrix &eigenVectors, Matrix &eigenValues, - uint16_t maxIterations, + uint32_t maxIterations, float tolerance) const { static_assert(rows > 1, "Matrix size must be > 1 for QR iteration"); Matrix A = *this; // copy original matrix eigenVectors.Identity(); - for (uint16_t iter = 0; iter < maxIterations; ++iter) { + for (uint32_t iter = 0; iter < maxIterations; ++iter) { Matrix Q, R; A.QRDecomposition(Q, R); @@ -543,14 +543,16 @@ void Matrix::EigenQR(Matrix &eigenVectors, // Check convergence: off-diagonal norm float offDiagSum = 0.f; - for (uint8_t i = 0; i < rows; ++i) { - for (uint8_t j = 0; j < rows; ++j) { - if (i != j) + for (uint8_t i = 0; i < rows; i++) { + for (uint8_t j = 0; j < rows; j++) { + if (i != j) { offDiagSum += fabs(A[i][j]); + } } } - if (offDiagSum < tolerance) + if (offDiagSum < tolerance) { break; + } } // eigenvalues are the diagonal elements of A diff --git a/src/Matrix.hpp b/src/Matrix.hpp index 7b4d87f..66f958d 100644 --- a/src/Matrix.hpp +++ b/src/Matrix.hpp @@ -221,21 +221,22 @@ public: * @param Q a buffer that will contain Q after the function completes * @param R a buffer that will contain R after the function completes */ - void QRDecomposition(Matrix &Q, Matrix &R) const; + void QRDecomposition(Matrix &Q, + Matrix &R) const; /** - * @brief Uses QR decomposition to efficiently calculate the eigenvectors and - * values of this matrix + * @brief Uses QR decomposition to efficiently calculate the eigenvectors + * and values of this matrix * @param eigenVectors a buffer that will contain the eigenvectors fo this * matrix * @param eigenValues a buffer that will contain the eigenValues fo this * matrix - * @param maxIterations the number of iterations to perform before giving up - * on reaching the given tolerance + * @param maxIterations the number of iterations to perform before giving + * up on reaching the given tolerance * @param tolerance the level of accuracy to obtain before stopping. */ void EigenQR(Matrix &eigenVectors, Matrix &eigenValues, - uint16_t maxIterations = 1000, float tolerance = 1e-6f) const; + uint32_t maxIterations = 1000, float tolerance = 1e-6f) const; protected: std::array matrix; diff --git a/unit-tests/matrix-tests.cpp b/unit-tests/matrix-tests.cpp index 3b0737f..45b05d2 100644 --- a/unit-tests/matrix-tests.cpp +++ b/unit-tests/matrix-tests.cpp @@ -355,7 +355,7 @@ TEST_CASE("Elementary Matrix Operations", "Matrix") { } } -TEST_CASE("Advanced Matrix Operations", "Matrix") { +TEST_CASE("QR Decompositions", "Matrix") { SECTION("2x2 QRDecomposition") { Matrix<2, 2> A{1.0f, 2.0f, 3.0f, 4.0f}; Matrix<2, 2> Q{}, R{}; @@ -423,4 +423,81 @@ TEST_CASE("Advanced Matrix Operations", "Matrix") { } } } + + SECTION("4x2 QRDecomposition") { + // A simple 4x2 matrix + Matrix<4, 2> A{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + + Matrix<4, 2> Q{}; + Matrix<2, 2> R{}; + A.QRDecomposition(Q, R); + + // Check that Q * R ≈ A + Matrix<4, 2> QR{}; + Q.Mult(R, QR); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 2; ++j) { + REQUIRE_THAT(QR[i][j], Catch::Matchers::WithinRel(A[i][j], 1e-4f)); + } + } + + // Check that Qᵀ * Q ≈ I₂ + Matrix<2, 4> Qt = Q.Transpose(); + Matrix<2, 2> QtQ{}; + Qt.Mult(Q, QtQ); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + if (i == j) + REQUIRE_THAT(QtQ[i][j], Catch::Matchers::WithinRel(1.0f, 1e-4f)); + else + REQUIRE_THAT(QtQ[i][j], Catch::Matchers::WithinAbs(0.0f, 1e-4f)); + } + } + + // Check R is upper triangular (i > j ⇒ R[i][j] ≈ 0) + for (int i = 1; i < 2; ++i) { + for (int j = 0; j < i; ++j) { + REQUIRE(std::fabs(R[i][j]) < 1e-4f); + } + } + } +} + +TEST_CASE("Eigenvalues and Vectors", "Matrix") { + SECTION("2x2 Eigen") { + Matrix<2, 2> A{1.0f, 2.0f, 3.0f, 4.0f}; + Matrix<2, 2> vectors{}; + Matrix<2, 1> values{}; + + A.EigenQR(vectors, values, 1000000, 1e-20f); + + REQUIRE_THAT(vectors[0][0], Catch::Matchers::WithinRel(0.41597f, 1e-4f)); + REQUIRE_THAT(vectors[1][0], Catch::Matchers::WithinRel(0.90938f, 1e-4f)); + REQUIRE_THAT(values[0][0], Catch::Matchers::WithinRel(5.372282f, 1e-4f)); + REQUIRE_THAT(values[1][0], Catch::Matchers::WithinRel(-0.372281f, 1e-4f)); + } + + SECTION("3x3 Eigen") { + // this symmetrix tridiagonal matrix is well behaved for testing + Matrix<3, 3> A{1, 2, 3, 4, 5, 6, 7, 8, 9}; + + Matrix<3, 3> vectors{}; + Matrix<3, 1> values{}; + A.EigenQR(vectors, values, 10000, 1e-8f); + + std::string strBuf1 = ""; + vectors.ToString(strBuf1); + std::cout << "Vectors:\n" << strBuf1 << std::endl; + strBuf1 = ""; + values.ToString(strBuf1); + std::cout << "Values:\n" << strBuf1 << std::endl; + + REQUIRE_THAT(vectors[0][0], Catch::Matchers::WithinRel(0.23197f, 1e-4f)); + REQUIRE_THAT(vectors[1][0], Catch::Matchers::WithinRel(0.525322f, 1e-4f)); + REQUIRE_THAT(vectors[2][0], Catch::Matchers::WithinRel(0.81867f, 1e-4f)); + REQUIRE_THAT(values[0][0], Catch::Matchers::WithinRel(16.1168f, 1e-4f)); + REQUIRE_THAT(values[1][0], Catch::Matchers::WithinRel(-1.11684f, 1e-4f)); + // TODO: Figure out what's wrong here + // REQUIRE_THAT(values[2][0], Catch::Matchers::WithinRel(-3.2583f, 1e-4f)); + } } \ No newline at end of file