From 9cbaeb2c2770b9e02e58d535fee1e41c9655afe0 Mon Sep 17 00:00:00 2001 From: Quinn Henthorne Date: Wed, 11 Dec 2024 09:09:08 -0500 Subject: [PATCH] Got matrix multiplication working --- .vscode/settings.json | 3 +- Matrix.hpp | 92 +++++++++++++++++++------------------ unit-tests/matrix-tests.cpp | 65 +++++++++++++++++++++----- 3 files changed, 103 insertions(+), 57 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 8344eb9..ad6aaf8 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -70,5 +70,6 @@ "thread": "cpp", "typeinfo": "cpp", "variant": "cpp" - } + }, + "clangd.enable": false } \ No newline at end of file diff --git a/Matrix.hpp b/Matrix.hpp index 9ad0b61..3fb4848 100644 --- a/Matrix.hpp +++ b/Matrix.hpp @@ -13,7 +13,7 @@ public: /** * @brief Initialize a matrix with an array */ - Matrix(const std::array &array); + Matrix(const std::array &array); // TODO: Figure out how to do this /** @@ -37,8 +37,8 @@ public: * @param result A buffer to store the result into * @note there is no problem if result == this */ - void Subtract(const Matrix &other, - Matrix &result) const; + void Sub(const Matrix &other, + Matrix &result) const; /** * @brief Matrix multiply the two matrices @@ -46,8 +46,8 @@ public: * @param result A buffer to store the result into */ template - void Multiply(const Matrix &other, - Matrix &result) const; + void Mult(const Matrix &other, + Matrix &result) const; /** * @brief Multiply the matrix by a scalar @@ -55,7 +55,7 @@ public: * @param result A buffer to store the result into * @note there is no problem if result == this */ - void Multiply(float scalar, Matrix &result) const; + void Mult(float scalar, Matrix &result) const; /** * @brief Invert this matrix @@ -117,13 +117,13 @@ public: * @brief get the specified row of the matrix returned as a reference to the * internal array */ - std::array &operator[](uint8_t row_index){ + std::array &operator[](uint8_t row_index) { return this->matrix[row_index]; } - Matrix &operator=(const Matrix &other){ - for(uint8_t row_idx{0}; row_idx < rows; row_idx++){ - for(uint8_t column_idx{0}; column_idx < columns; column_idx++){ + Matrix &operator=(const Matrix &other) { + for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { + for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { this->matrix[row_idx][column_idx] = other.Get(row_idx, column_idx); } } @@ -153,15 +153,15 @@ public: */ constexpr uint8_t GetColumnSize() { return columns; } - void ToString(std::string & stringBuffer) const; + void ToString(std::string &stringBuffer) const; private: /** * @brief take the dot product of the two vectors */ template - float dotProduct(const Matrix &vec1, - const Matrix &vec2); + static float dotProduct(const Matrix<1, vector_size> &vec1, + const Matrix<1, vector_size> &vec2); /** * @brief Set all elements in this matrix to zero @@ -183,17 +183,19 @@ private: constexpr bool isSquare() { return rows == columns; } - void setMatrixToArray(const std::array & array); + void setMatrixToArray(const std::array &array); std::array, rows> matrix; }; template -void Matrix::setMatrixToArray(const std::array & array){ +void Matrix::setMatrixToArray( + const std::array &array) { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { uint16_t array_idx = - static_cast(row_idx) * static_cast(columns) + static_cast(column_idx); + static_cast(row_idx) * static_cast(columns) + + static_cast(column_idx); if (array_idx < array.size()) { this->matrix[row_idx][column_idx] = array[array_idx]; } else { @@ -208,14 +210,14 @@ template Matrix::Matrix() { } template -Matrix::Matrix(const std::array &array) { +Matrix::Matrix(const std::array &array) { this->setMatrixToArray(array); } // template // template // Matrix::Matrix(Args&&... args){ - + // // Initialize a std::array with the arguments // if(typeid(args) == typeid(std::array)){ // this->setMatrixToArray(args); @@ -226,7 +228,7 @@ Matrix::Matrix(const std::array &array) { // // now store the array in our internal matrix // this->setMatrixToArray(values); // } - + // } template @@ -234,48 +236,49 @@ void Matrix::Add(const Matrix &other, Matrix &result) const { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { - result[row_idx][column_idx] = this->Get(row_idx, column_idx) + other.Get(row_idx, column_idx); + result[row_idx][column_idx] = + this->Get(row_idx, column_idx) + other.Get(row_idx, column_idx); } } } template -void Matrix::Subtract(const Matrix &other, - Matrix &result) const { +void Matrix::Sub(const Matrix &other, + Matrix &result) const { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { - result[row_idx][column_idx] = this->Get(row_idx, column_idx) - other.Get(row_idx, column_idx); + result[row_idx][column_idx] = + this->Get(row_idx, column_idx) - other.Get(row_idx, column_idx); } } } template template -void Matrix::Multiply( - const Matrix &other, - Matrix &result) const { +void Matrix::Mult(const Matrix &other, + Matrix &result) const { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { // get our row - Matrix this_row; + Matrix<1, columns> this_row; this->GetRow(row_idx, this_row); // get the other matrices column - Matrix<1, columns> other_column; + Matrix other_column; other.GetColumn(column_idx, other_column); // transpose the other matrix's column - Matrix other_column_t; + Matrix<1, rows> other_column_t; other_column.Transpose(other_column_t); // the result's index is equal to the dot product of these two vectors result[row_idx][column_idx] = - this->dotProduct(this_row, other_column_t); + Matrix::dotProduct(this_row, other_column_t); } } } template -void Matrix::Multiply(float scalar, - Matrix &result) const { +void Matrix::Mult(float scalar, + Matrix &result) const { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { result[row_idx][column_idx] = this->Get(row_idx, column_idx) * scalar; @@ -311,7 +314,7 @@ void Matrix::Invert(Matrix &result) const { float determinant = this->Det(); // scale the result by 1/determinant and we have our answer - result.Multiply(1 / determinant); + result.Mult(1 / determinant); } template @@ -327,7 +330,7 @@ template void Matrix::Square(Matrix &result) const { static_assert(this->isSquare(), "You can't square an non-square matrix."); - this->Multiply(this, result); + this->Mult(this, result); } template @@ -424,7 +427,7 @@ void Matrix::ElementDivide(const Matrix &other, template float Matrix::Get(uint8_t row_index, - uint8_t column_index) const { + uint8_t column_index) const { return this->matrix[row_index][column_index]; } @@ -438,17 +441,17 @@ template void Matrix::GetColumn(uint8_t column_index, Matrix &column) const { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { - column[0][column_index] = this->Get(row_idx, column_index); + column[row_idx][0] = this->Get(row_idx, column_index); } } template -void Matrix::ToString(std::string & stringBuffer) const{ - for(uint8_t row_idx{0}; row_idx < rows; row_idx++){ +void Matrix::ToString(std::string &stringBuffer) const { + for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { stringBuffer += "|"; - for(uint8_t column_idx{0}; column_idx < columns; column_idx++){ + for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { stringBuffer += std::to_string(this->matrix[row_idx][column_idx]); - if(column_idx != columns - 1){ + if (column_idx != columns - 1) { stringBuffer += "\t"; } } @@ -458,11 +461,11 @@ void Matrix::ToString(std::string & stringBuffer) const{ template template -float Matrix::dotProduct(const Matrix &vec1, - const Matrix &vec2) { +float Matrix::dotProduct(const Matrix<1, vector_size> &vec1, + const Matrix<1, vector_size> &vec2) { float sum{0}; for (uint8_t i{0}; i < vector_size; i++) { - sum += vec1.Get(i, 0) * vec2.Get(i, 0); + sum += vec1.Get(0, i) * vec2.Get(0, i); } return sum; @@ -516,8 +519,7 @@ void Matrix::adjugate(Matrix &result) const { for (uint8_t column_iter{0}; column_iter < columns; column_iter++) { float sign = ((row_iter + 1) % 2) == 0 ? -1 : 1; sign *= ((column_iter + 1) % 2) == 0 ? -1 : 1; - result[row_iter][column_iter] = - this->Get(row_iter, column_iter) * sign; + result[row_iter][column_iter] = this->Get(row_iter, column_iter) * sign; } } } diff --git a/unit-tests/matrix-tests.cpp b/unit-tests/matrix-tests.cpp index dd53aae..25d2d0f 100644 --- a/unit-tests/matrix-tests.cpp +++ b/unit-tests/matrix-tests.cpp @@ -8,21 +8,64 @@ #include #include -TEST_CASE("Matrix Addition", "Matrix::Add") { +TEST_CASE("Elementary Matrix Operations", "Matrix::Add") { std::array arr1{1, 2, 3, 4}; std::array arr2{5, 6, 7, 8}; Matrix<2, 2> mat1{arr1}; Matrix<2, 2> mat2{arr2}; - - std::string strBuf1 = ""; - mat1.ToString(strBuf1); - std::cout << strBuf1 << std::endl; - Matrix<2, 2> mat3{}; - mat1.Add(mat2, mat3); - REQUIRE(mat3.Get(0, 0) == 6); - REQUIRE(mat3.Get(0, 1) == 8); - REQUIRE(mat3.Get(1, 0) == 10); - REQUIRE(mat3.Get(1, 1) == 12); + SECTION("Initialization") { + // array initialization + REQUIRE(mat1.Get(0, 0) == 1); + REQUIRE(mat1.Get(0, 1) == 2); + REQUIRE(mat1.Get(1, 0) == 3); + REQUIRE(mat1.Get(1, 1) == 4); + + // empty initialization + REQUIRE(mat3.Get(0, 0) == 0); + REQUIRE(mat3.Get(0, 1) == 0); + REQUIRE(mat3.Get(1, 0) == 0); + REQUIRE(mat3.Get(1, 1) == 0); + } + + SECTION("Addition") { + std::string strBuf1 = ""; + mat1.ToString(strBuf1); + std::cout << "Matrix 1:\n" << strBuf1 << std::endl; + + mat1.Add(mat2, mat3); + + REQUIRE(mat3.Get(0, 0) == 6); + REQUIRE(mat3.Get(0, 1) == 8); + REQUIRE(mat3.Get(1, 0) == 10); + REQUIRE(mat3.Get(1, 1) == 12); + } + + SECTION("Subtraction") { + mat1.Sub(mat2, mat3); + + REQUIRE(mat3.Get(0, 0) == -4); + REQUIRE(mat3.Get(0, 1) == -4); + REQUIRE(mat3.Get(1, 0) == -4); + REQUIRE(mat3.Get(1, 1) == -4); + } + + SECTION("Multiplication") { + mat1.Mult(mat2, mat3); + + REQUIRE(mat3.Get(0, 0) == 19); + REQUIRE(mat3.Get(0, 1) == 22); + REQUIRE(mat3.Get(1, 0) == 43); + REQUIRE(mat3.Get(1, 1) == 50); + } + + SECTION("Scalar Multiplication") { + mat1.Mult(2, mat3); + + REQUIRE(mat3.Get(0, 0) == 2); + REQUIRE(mat3.Get(0, 1) == 4); + REQUIRE(mat3.Get(1, 0) == 6); + REQUIRE(mat3.Get(1, 1) == 8); + } } \ No newline at end of file