Got matrix multiplication working

This commit is contained in:
Quinn Henthorne
2024-12-11 09:09:08 -05:00
parent 1fb211912d
commit 9cbaeb2c27
3 changed files with 103 additions and 57 deletions

View File

@@ -70,5 +70,6 @@
"thread": "cpp",
"typeinfo": "cpp",
"variant": "cpp"
}
},
"clangd.enable": false
}

View File

@@ -37,7 +37,7 @@ public:
* @param result A buffer to store the result into
* @note there is no problem if result == this
*/
void Subtract(const Matrix<rows, columns> &other,
void Sub(const Matrix<rows, columns> &other,
Matrix<rows, columns> &result) const;
/**
@@ -46,7 +46,7 @@ public:
* @param result A buffer to store the result into
*/
template <uint8_t other_columns>
void Multiply(const Matrix<rows, columns> &other,
void Mult(const Matrix<rows, columns> &other,
Matrix<columns, other_columns> &result) const;
/**
@@ -55,7 +55,7 @@ public:
* @param result A buffer to store the result into
* @note there is no problem if result == this
*/
void Multiply(float scalar, Matrix<rows, columns> &result) const;
void Mult(float scalar, Matrix<rows, columns> &result) const;
/**
* @brief Invert this matrix
@@ -160,8 +160,8 @@ private:
* @brief take the dot product of the two vectors
*/
template <uint8_t vector_size>
float dotProduct(const Matrix<vector_size, 1> &vec1,
const Matrix<vector_size, 1> &vec2);
static float dotProduct(const Matrix<1, vector_size> &vec1,
const Matrix<1, vector_size> &vec2);
/**
* @brief Set all elements in this matrix to zero
@@ -189,11 +189,13 @@ private:
};
template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::setMatrixToArray(const std::array<float, rows*columns> & array){
void Matrix<rows, columns>::setMatrixToArray(
const std::array<float, rows * columns> &array) {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
uint16_t array_idx =
static_cast<uint16_t>(row_idx) * static_cast<uint16_t>(columns) + static_cast<uint16_t>(column_idx);
static_cast<uint16_t>(row_idx) * static_cast<uint16_t>(columns) +
static_cast<uint16_t>(column_idx);
if (array_idx < array.size()) {
this->matrix[row_idx][column_idx] = array[array_idx];
} else {
@@ -234,47 +236,48 @@ void Matrix<rows, columns>::Add(const Matrix<rows, columns> &other,
Matrix<rows, columns> &result) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
result[row_idx][column_idx] = this->Get(row_idx, column_idx) + other.Get(row_idx, column_idx);
result[row_idx][column_idx] =
this->Get(row_idx, column_idx) + other.Get(row_idx, column_idx);
}
}
}
template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::Subtract(const Matrix<rows, columns> &other,
void Matrix<rows, columns>::Sub(const Matrix<rows, columns> &other,
Matrix<rows, columns> &result) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
result[row_idx][column_idx] = this->Get(row_idx, column_idx) - other.Get(row_idx, column_idx);
result[row_idx][column_idx] =
this->Get(row_idx, column_idx) - other.Get(row_idx, column_idx);
}
}
}
template <uint8_t rows, uint8_t columns>
template <uint8_t other_columns>
void Matrix<rows, columns>::Multiply(
const Matrix<rows, columns> &other,
void Matrix<rows, columns>::Mult(const Matrix<rows, columns> &other,
Matrix<columns, other_columns> &result) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
// get our row
Matrix<rows, 1> this_row;
Matrix<1, columns> this_row;
this->GetRow(row_idx, this_row);
// get the other matrices column
Matrix<1, columns> other_column;
Matrix<rows, 1> other_column;
other.GetColumn(column_idx, other_column);
// transpose the other matrix's column
Matrix<columns, 1> other_column_t;
Matrix<1, rows> other_column_t;
other_column.Transpose(other_column_t);
// the result's index is equal to the dot product of these two vectors
result[row_idx][column_idx] =
this->dotProduct(this_row, other_column_t);
Matrix<rows, columns>::dotProduct(this_row, other_column_t);
}
}
}
template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::Multiply(float scalar,
void Matrix<rows, columns>::Mult(float scalar,
Matrix<rows, columns> &result) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
@@ -311,7 +314,7 @@ void Matrix<rows, columns>::Invert(Matrix<rows, columns> &result) const {
float determinant = this->Det();
// scale the result by 1/determinant and we have our answer
result.Multiply(1 / determinant);
result.Mult(1 / determinant);
}
template <uint8_t rows, uint8_t columns>
@@ -327,7 +330,7 @@ template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::Square(Matrix<rows, columns> &result) const {
static_assert(this->isSquare(), "You can't square an non-square matrix.");
this->Multiply(this, result);
this->Mult(this, result);
}
template <uint8_t rows, uint8_t columns>
@@ -438,7 +441,7 @@ template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::GetColumn(uint8_t column_index,
Matrix<rows, 1> &column) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
column[0][column_index] = this->Get(row_idx, column_index);
column[row_idx][0] = this->Get(row_idx, column_index);
}
}
@@ -458,11 +461,11 @@ void Matrix<rows, columns>::ToString(std::string & stringBuffer) const{
template <uint8_t rows, uint8_t columns>
template <uint8_t vector_size>
float Matrix<rows, columns>::dotProduct(const Matrix<vector_size, 1> &vec1,
const Matrix<vector_size, 1> &vec2) {
float Matrix<rows, columns>::dotProduct(const Matrix<1, vector_size> &vec1,
const Matrix<1, vector_size> &vec2) {
float sum{0};
for (uint8_t i{0}; i < vector_size; i++) {
sum += vec1.Get(i, 0) * vec2.Get(i, 0);
sum += vec1.Get(0, i) * vec2.Get(0, i);
}
return sum;
@@ -516,8 +519,7 @@ void Matrix<rows, columns>::adjugate(Matrix<rows, columns> &result) const {
for (uint8_t column_iter{0}; column_iter < columns; column_iter++) {
float sign = ((row_iter + 1) % 2) == 0 ? -1 : 1;
sign *= ((column_iter + 1) % 2) == 0 ? -1 : 1;
result[row_iter][column_iter] =
this->Get(row_iter, column_iter) * sign;
result[row_iter][column_iter] = this->Get(row_iter, column_iter) * sign;
}
}
}

View File

@@ -8,17 +8,32 @@
#include <array>
#include <iostream>
TEST_CASE("Matrix Addition", "Matrix::Add") {
TEST_CASE("Elementary Matrix Operations", "Matrix::Add") {
std::array<float, 4> arr1{1, 2, 3, 4};
std::array<float, 4> arr2{5, 6, 7, 8};
Matrix<2, 2> mat1{arr1};
Matrix<2, 2> mat2{arr2};
Matrix<2, 2> mat3{};
SECTION("Initialization") {
// array initialization
REQUIRE(mat1.Get(0, 0) == 1);
REQUIRE(mat1.Get(0, 1) == 2);
REQUIRE(mat1.Get(1, 0) == 3);
REQUIRE(mat1.Get(1, 1) == 4);
// empty initialization
REQUIRE(mat3.Get(0, 0) == 0);
REQUIRE(mat3.Get(0, 1) == 0);
REQUIRE(mat3.Get(1, 0) == 0);
REQUIRE(mat3.Get(1, 1) == 0);
}
SECTION("Addition") {
std::string strBuf1 = "";
mat1.ToString(strBuf1);
std::cout << strBuf1 << std::endl;
std::cout << "Matrix 1:\n" << strBuf1 << std::endl;
Matrix<2, 2> mat3{};
mat1.Add(mat2, mat3);
REQUIRE(mat3.Get(0, 0) == 6);
@@ -26,3 +41,31 @@ TEST_CASE("Matrix Addition", "Matrix::Add") {
REQUIRE(mat3.Get(1, 0) == 10);
REQUIRE(mat3.Get(1, 1) == 12);
}
SECTION("Subtraction") {
mat1.Sub(mat2, mat3);
REQUIRE(mat3.Get(0, 0) == -4);
REQUIRE(mat3.Get(0, 1) == -4);
REQUIRE(mat3.Get(1, 0) == -4);
REQUIRE(mat3.Get(1, 1) == -4);
}
SECTION("Multiplication") {
mat1.Mult(mat2, mat3);
REQUIRE(mat3.Get(0, 0) == 19);
REQUIRE(mat3.Get(0, 1) == 22);
REQUIRE(mat3.Get(1, 0) == 43);
REQUIRE(mat3.Get(1, 1) == 50);
}
SECTION("Scalar Multiplication") {
mat1.Mult(2, mat3);
REQUIRE(mat3.Get(0, 0) == 2);
REQUIRE(mat3.Get(0, 1) == 4);
REQUIRE(mat3.Get(1, 0) == 6);
REQUIRE(mat3.Get(1, 1) == 8);
}
}