Got matrix multiplication working

This commit is contained in:
Quinn Henthorne
2024-12-11 09:09:08 -05:00
parent 1fb211912d
commit 9cbaeb2c27
3 changed files with 103 additions and 57 deletions

View File

@@ -70,5 +70,6 @@
"thread": "cpp", "thread": "cpp",
"typeinfo": "cpp", "typeinfo": "cpp",
"variant": "cpp" "variant": "cpp"
} },
"clangd.enable": false
} }

View File

@@ -13,7 +13,7 @@ public:
/** /**
* @brief Initialize a matrix with an array * @brief Initialize a matrix with an array
*/ */
Matrix(const std::array<float, rows*columns> &array); Matrix(const std::array<float, rows * columns> &array);
// TODO: Figure out how to do this // TODO: Figure out how to do this
/** /**
@@ -37,7 +37,7 @@ public:
* @param result A buffer to store the result into * @param result A buffer to store the result into
* @note there is no problem if result == this * @note there is no problem if result == this
*/ */
void Subtract(const Matrix<rows, columns> &other, void Sub(const Matrix<rows, columns> &other,
Matrix<rows, columns> &result) const; Matrix<rows, columns> &result) const;
/** /**
@@ -46,7 +46,7 @@ public:
* @param result A buffer to store the result into * @param result A buffer to store the result into
*/ */
template <uint8_t other_columns> template <uint8_t other_columns>
void Multiply(const Matrix<rows, columns> &other, void Mult(const Matrix<rows, columns> &other,
Matrix<columns, other_columns> &result) const; Matrix<columns, other_columns> &result) const;
/** /**
@@ -55,7 +55,7 @@ public:
* @param result A buffer to store the result into * @param result A buffer to store the result into
* @note there is no problem if result == this * @note there is no problem if result == this
*/ */
void Multiply(float scalar, Matrix<rows, columns> &result) const; void Mult(float scalar, Matrix<rows, columns> &result) const;
/** /**
* @brief Invert this matrix * @brief Invert this matrix
@@ -117,13 +117,13 @@ public:
* @brief get the specified row of the matrix returned as a reference to the * @brief get the specified row of the matrix returned as a reference to the
* internal array * internal array
*/ */
std::array<float, columns> &operator[](uint8_t row_index){ std::array<float, columns> &operator[](uint8_t row_index) {
return this->matrix[row_index]; return this->matrix[row_index];
} }
Matrix<rows, columns> &operator=(const Matrix<rows, columns> &other){ Matrix<rows, columns> &operator=(const Matrix<rows, columns> &other) {
for(uint8_t row_idx{0}; row_idx < rows; row_idx++){ for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for(uint8_t column_idx{0}; column_idx < columns; column_idx++){ for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
this->matrix[row_idx][column_idx] = other.Get(row_idx, column_idx); this->matrix[row_idx][column_idx] = other.Get(row_idx, column_idx);
} }
} }
@@ -153,15 +153,15 @@ public:
*/ */
constexpr uint8_t GetColumnSize() { return columns; } constexpr uint8_t GetColumnSize() { return columns; }
void ToString(std::string & stringBuffer) const; void ToString(std::string &stringBuffer) const;
private: private:
/** /**
* @brief take the dot product of the two vectors * @brief take the dot product of the two vectors
*/ */
template <uint8_t vector_size> template <uint8_t vector_size>
float dotProduct(const Matrix<vector_size, 1> &vec1, static float dotProduct(const Matrix<1, vector_size> &vec1,
const Matrix<vector_size, 1> &vec2); const Matrix<1, vector_size> &vec2);
/** /**
* @brief Set all elements in this matrix to zero * @brief Set all elements in this matrix to zero
@@ -183,17 +183,19 @@ private:
constexpr bool isSquare() { return rows == columns; } constexpr bool isSquare() { return rows == columns; }
void setMatrixToArray(const std::array<float, rows*columns> & array); void setMatrixToArray(const std::array<float, rows * columns> &array);
std::array<std::array<float, columns>, rows> matrix; std::array<std::array<float, columns>, rows> matrix;
}; };
template <uint8_t rows, uint8_t columns> template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::setMatrixToArray(const std::array<float, rows*columns> & array){ void Matrix<rows, columns>::setMatrixToArray(
const std::array<float, rows * columns> &array) {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
uint16_t array_idx = uint16_t array_idx =
static_cast<uint16_t>(row_idx) * static_cast<uint16_t>(columns) + static_cast<uint16_t>(column_idx); static_cast<uint16_t>(row_idx) * static_cast<uint16_t>(columns) +
static_cast<uint16_t>(column_idx);
if (array_idx < array.size()) { if (array_idx < array.size()) {
this->matrix[row_idx][column_idx] = array[array_idx]; this->matrix[row_idx][column_idx] = array[array_idx];
} else { } else {
@@ -208,7 +210,7 @@ template <uint8_t rows, uint8_t columns> Matrix<rows, columns>::Matrix() {
} }
template <uint8_t rows, uint8_t columns> template <uint8_t rows, uint8_t columns>
Matrix<rows, columns>::Matrix(const std::array<float, rows*columns> &array) { Matrix<rows, columns>::Matrix(const std::array<float, rows * columns> &array) {
this->setMatrixToArray(array); this->setMatrixToArray(array);
} }
@@ -234,47 +236,48 @@ void Matrix<rows, columns>::Add(const Matrix<rows, columns> &other,
Matrix<rows, columns> &result) const { Matrix<rows, columns> &result) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
result[row_idx][column_idx] = this->Get(row_idx, column_idx) + other.Get(row_idx, column_idx); result[row_idx][column_idx] =
this->Get(row_idx, column_idx) + other.Get(row_idx, column_idx);
} }
} }
} }
template <uint8_t rows, uint8_t columns> template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::Subtract(const Matrix<rows, columns> &other, void Matrix<rows, columns>::Sub(const Matrix<rows, columns> &other,
Matrix<rows, columns> &result) const { Matrix<rows, columns> &result) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
result[row_idx][column_idx] = this->Get(row_idx, column_idx) - other.Get(row_idx, column_idx); result[row_idx][column_idx] =
this->Get(row_idx, column_idx) - other.Get(row_idx, column_idx);
} }
} }
} }
template <uint8_t rows, uint8_t columns> template <uint8_t rows, uint8_t columns>
template <uint8_t other_columns> template <uint8_t other_columns>
void Matrix<rows, columns>::Multiply( void Matrix<rows, columns>::Mult(const Matrix<rows, columns> &other,
const Matrix<rows, columns> &other,
Matrix<columns, other_columns> &result) const { Matrix<columns, other_columns> &result) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
// get our row // get our row
Matrix<rows, 1> this_row; Matrix<1, columns> this_row;
this->GetRow(row_idx, this_row); this->GetRow(row_idx, this_row);
// get the other matrices column // get the other matrices column
Matrix<1, columns> other_column; Matrix<rows, 1> other_column;
other.GetColumn(column_idx, other_column); other.GetColumn(column_idx, other_column);
// transpose the other matrix's column // transpose the other matrix's column
Matrix<columns, 1> other_column_t; Matrix<1, rows> other_column_t;
other_column.Transpose(other_column_t); other_column.Transpose(other_column_t);
// the result's index is equal to the dot product of these two vectors // the result's index is equal to the dot product of these two vectors
result[row_idx][column_idx] = result[row_idx][column_idx] =
this->dotProduct(this_row, other_column_t); Matrix<rows, columns>::dotProduct(this_row, other_column_t);
} }
} }
} }
template <uint8_t rows, uint8_t columns> template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::Multiply(float scalar, void Matrix<rows, columns>::Mult(float scalar,
Matrix<rows, columns> &result) const { Matrix<rows, columns> &result) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
@@ -311,7 +314,7 @@ void Matrix<rows, columns>::Invert(Matrix<rows, columns> &result) const {
float determinant = this->Det(); float determinant = this->Det();
// scale the result by 1/determinant and we have our answer // scale the result by 1/determinant and we have our answer
result.Multiply(1 / determinant); result.Mult(1 / determinant);
} }
template <uint8_t rows, uint8_t columns> template <uint8_t rows, uint8_t columns>
@@ -327,7 +330,7 @@ template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::Square(Matrix<rows, columns> &result) const { void Matrix<rows, columns>::Square(Matrix<rows, columns> &result) const {
static_assert(this->isSquare(), "You can't square an non-square matrix."); static_assert(this->isSquare(), "You can't square an non-square matrix.");
this->Multiply(this, result); this->Mult(this, result);
} }
template <uint8_t rows, uint8_t columns> template <uint8_t rows, uint8_t columns>
@@ -438,17 +441,17 @@ template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::GetColumn(uint8_t column_index, void Matrix<rows, columns>::GetColumn(uint8_t column_index,
Matrix<rows, 1> &column) const { Matrix<rows, 1> &column) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
column[0][column_index] = this->Get(row_idx, column_index); column[row_idx][0] = this->Get(row_idx, column_index);
} }
} }
template <uint8_t rows, uint8_t columns> template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::ToString(std::string & stringBuffer) const{ void Matrix<rows, columns>::ToString(std::string &stringBuffer) const {
for(uint8_t row_idx{0}; row_idx < rows; row_idx++){ for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
stringBuffer += "|"; stringBuffer += "|";
for(uint8_t column_idx{0}; column_idx < columns; column_idx++){ for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
stringBuffer += std::to_string(this->matrix[row_idx][column_idx]); stringBuffer += std::to_string(this->matrix[row_idx][column_idx]);
if(column_idx != columns - 1){ if (column_idx != columns - 1) {
stringBuffer += "\t"; stringBuffer += "\t";
} }
} }
@@ -458,11 +461,11 @@ void Matrix<rows, columns>::ToString(std::string & stringBuffer) const{
template <uint8_t rows, uint8_t columns> template <uint8_t rows, uint8_t columns>
template <uint8_t vector_size> template <uint8_t vector_size>
float Matrix<rows, columns>::dotProduct(const Matrix<vector_size, 1> &vec1, float Matrix<rows, columns>::dotProduct(const Matrix<1, vector_size> &vec1,
const Matrix<vector_size, 1> &vec2) { const Matrix<1, vector_size> &vec2) {
float sum{0}; float sum{0};
for (uint8_t i{0}; i < vector_size; i++) { for (uint8_t i{0}; i < vector_size; i++) {
sum += vec1.Get(i, 0) * vec2.Get(i, 0); sum += vec1.Get(0, i) * vec2.Get(0, i);
} }
return sum; return sum;
@@ -516,8 +519,7 @@ void Matrix<rows, columns>::adjugate(Matrix<rows, columns> &result) const {
for (uint8_t column_iter{0}; column_iter < columns; column_iter++) { for (uint8_t column_iter{0}; column_iter < columns; column_iter++) {
float sign = ((row_iter + 1) % 2) == 0 ? -1 : 1; float sign = ((row_iter + 1) % 2) == 0 ? -1 : 1;
sign *= ((column_iter + 1) % 2) == 0 ? -1 : 1; sign *= ((column_iter + 1) % 2) == 0 ? -1 : 1;
result[row_iter][column_iter] = result[row_iter][column_iter] = this->Get(row_iter, column_iter) * sign;
this->Get(row_iter, column_iter) * sign;
} }
} }
} }

View File

@@ -8,21 +8,64 @@
#include <array> #include <array>
#include <iostream> #include <iostream>
TEST_CASE("Matrix Addition", "Matrix::Add") { TEST_CASE("Elementary Matrix Operations", "Matrix::Add") {
std::array<float, 4> arr1{1, 2, 3, 4}; std::array<float, 4> arr1{1, 2, 3, 4};
std::array<float, 4> arr2{5, 6, 7, 8}; std::array<float, 4> arr2{5, 6, 7, 8};
Matrix<2, 2> mat1{arr1}; Matrix<2, 2> mat1{arr1};
Matrix<2, 2> mat2{arr2}; Matrix<2, 2> mat2{arr2};
Matrix<2, 2> mat3{};
SECTION("Initialization") {
// array initialization
REQUIRE(mat1.Get(0, 0) == 1);
REQUIRE(mat1.Get(0, 1) == 2);
REQUIRE(mat1.Get(1, 0) == 3);
REQUIRE(mat1.Get(1, 1) == 4);
// empty initialization
REQUIRE(mat3.Get(0, 0) == 0);
REQUIRE(mat3.Get(0, 1) == 0);
REQUIRE(mat3.Get(1, 0) == 0);
REQUIRE(mat3.Get(1, 1) == 0);
}
SECTION("Addition") {
std::string strBuf1 = ""; std::string strBuf1 = "";
mat1.ToString(strBuf1); mat1.ToString(strBuf1);
std::cout << strBuf1 << std::endl; std::cout << "Matrix 1:\n" << strBuf1 << std::endl;
Matrix<2, 2> mat3{};
mat1.Add(mat2, mat3); mat1.Add(mat2, mat3);
REQUIRE(mat3.Get(0, 0) == 6); REQUIRE(mat3.Get(0, 0) == 6);
REQUIRE(mat3.Get(0, 1) == 8); REQUIRE(mat3.Get(0, 1) == 8);
REQUIRE(mat3.Get(1, 0) == 10); REQUIRE(mat3.Get(1, 0) == 10);
REQUIRE(mat3.Get(1, 1) == 12); REQUIRE(mat3.Get(1, 1) == 12);
}
SECTION("Subtraction") {
mat1.Sub(mat2, mat3);
REQUIRE(mat3.Get(0, 0) == -4);
REQUIRE(mat3.Get(0, 1) == -4);
REQUIRE(mat3.Get(1, 0) == -4);
REQUIRE(mat3.Get(1, 1) == -4);
}
SECTION("Multiplication") {
mat1.Mult(mat2, mat3);
REQUIRE(mat3.Get(0, 0) == 19);
REQUIRE(mat3.Get(0, 1) == 22);
REQUIRE(mat3.Get(1, 0) == 43);
REQUIRE(mat3.Get(1, 1) == 50);
}
SECTION("Scalar Multiplication") {
mat1.Mult(2, mat3);
REQUIRE(mat3.Get(0, 0) == 2);
REQUIRE(mat3.Get(0, 1) == 4);
REQUIRE(mat3.Get(1, 0) == 6);
REQUIRE(mat3.Get(1, 1) == 8);
}
} }