Made unit tests a little better and fixed matrix multiplication errors for non-square amtrices
Some checks failed
Merge-Checker / build_and_test (pull_request) Failing after 20s

This commit is contained in:
2025-06-02 10:49:16 -04:00
parent 6fdab5be30
commit 37556c7c81
2 changed files with 130 additions and 56 deletions

View File

@@ -105,7 +105,7 @@ Matrix<rows, columns>::Mult(const Matrix<columns, other_columns> &other,
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
// get our row
this->GetRow(row_idx, this_row);
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
for (uint8_t column_idx{0}; column_idx < other_columns; column_idx++) {
// get the other matrix'ss column
other.GetColumn(column_idx, other_column);
@@ -491,6 +491,8 @@ void Matrix<rows, columns>::SetSubMatrix(
template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::QRDecomposition(Matrix<rows, columns> &Q,
Matrix<columns, columns> &R) const {
static_assert(columns <= rows, "QR decomposition requires columns <= rows");
// Gram-Schmidt orthogonalization
Matrix<rows, 1> a_col, u, e, proj;
Matrix<rows, 1> q_col;
@@ -512,18 +514,18 @@ void Matrix<rows, columns>::QRDecomposition(Matrix<rows, columns> &Q,
}
float norm = sqrt(Matrix<rows, 1>::DotProduct(u, u));
if (norm < 1e-12f)
if (norm == 0) {
norm = 1e-12f; // avoid div by zero
}
for (uint8_t i = 0; i < rows; ++i)
for (uint8_t i = 0; i < rows; ++i) {
Q[i][k] = u[i][0] / norm;
}
R[k][k] = norm;
}
}
// Compute eigenvalues and eigenvectors by QR iteration
// maxIterations: safety limit, tolerance: stop criteria
template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::EigenQR(Matrix<rows, rows> &eigenVectors,
Matrix<rows, 1> &eigenValues,
@@ -531,33 +533,35 @@ void Matrix<rows, columns>::EigenQR(Matrix<rows, rows> &eigenVectors,
float tolerance) const {
static_assert(rows > 1, "Matrix size must be > 1 for QR iteration");
Matrix<rows, rows> A = *this; // copy original matrix
eigenVectors.Identity();
Matrix<rows, rows> Ak = *this; // Copy original matrix
Matrix<rows, rows> QQ{};
QQ.Identity();
for (uint32_t iter = 0; iter < maxIterations; ++iter) {
Matrix<rows, rows> Q, R;
A.QRDecomposition(Q, R);
Ak.QRDecomposition(Q, R);
A = R * Q;
eigenVectors = eigenVectors * Q;
Ak = R * Q;
QQ = QQ * Q;
// Check convergence: off-diagonal norm
float offDiagSum = 0.f;
for (uint8_t i = 0; i < rows; i++) {
for (uint8_t j = 0; j < rows; j++) {
if (i != j) {
offDiagSum += fabs(A[i][j]);
}
float offDiagSum = 0.0f;
for (uint32_t row = 1; row < rows; row++) {
for (uint32_t column = 0; column < row; column++) {
offDiagSum += fabs(Ak[row][column]);
}
}
if (offDiagSum < tolerance) {
break;
}
}
// eigenvalues are the diagonal elements of A
for (uint8_t i = 0; i < rows; ++i)
eigenValues[i][0] = A[i][i];
// Diagonal elements are the eigenvalues
for (uint8_t i = 0; i < rows; i++) {
eigenValues[i][0] = Ak[i][i];
}
eigenVectors = QQ;
}
#endif // MATRIX_H_