Got matrix multiplication working

This commit is contained in:
Quinn Henthorne
2024-12-11 09:09:08 -05:00
parent 1fb211912d
commit 9cbaeb2c27
3 changed files with 103 additions and 57 deletions

View File

@@ -70,5 +70,6 @@
"thread": "cpp",
"typeinfo": "cpp",
"variant": "cpp"
}
},
"clangd.enable": false
}

View File

@@ -13,7 +13,7 @@ public:
/**
* @brief Initialize a matrix with an array
*/
Matrix(const std::array<float, rows*columns> &array);
Matrix(const std::array<float, rows * columns> &array);
// TODO: Figure out how to do this
/**
@@ -37,8 +37,8 @@ public:
* @param result A buffer to store the result into
* @note there is no problem if result == this
*/
void Subtract(const Matrix<rows, columns> &other,
Matrix<rows, columns> &result) const;
void Sub(const Matrix<rows, columns> &other,
Matrix<rows, columns> &result) const;
/**
* @brief Matrix multiply the two matrices
@@ -46,8 +46,8 @@ public:
* @param result A buffer to store the result into
*/
template <uint8_t other_columns>
void Multiply(const Matrix<rows, columns> &other,
Matrix<columns, other_columns> &result) const;
void Mult(const Matrix<rows, columns> &other,
Matrix<columns, other_columns> &result) const;
/**
* @brief Multiply the matrix by a scalar
@@ -55,7 +55,7 @@ public:
* @param result A buffer to store the result into
* @note there is no problem if result == this
*/
void Multiply(float scalar, Matrix<rows, columns> &result) const;
void Mult(float scalar, Matrix<rows, columns> &result) const;
/**
* @brief Invert this matrix
@@ -117,13 +117,13 @@ public:
* @brief get the specified row of the matrix returned as a reference to the
* internal array
*/
std::array<float, columns> &operator[](uint8_t row_index){
std::array<float, columns> &operator[](uint8_t row_index) {
return this->matrix[row_index];
}
Matrix<rows, columns> &operator=(const Matrix<rows, columns> &other){
for(uint8_t row_idx{0}; row_idx < rows; row_idx++){
for(uint8_t column_idx{0}; column_idx < columns; column_idx++){
Matrix<rows, columns> &operator=(const Matrix<rows, columns> &other) {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
this->matrix[row_idx][column_idx] = other.Get(row_idx, column_idx);
}
}
@@ -153,15 +153,15 @@ public:
*/
constexpr uint8_t GetColumnSize() { return columns; }
void ToString(std::string & stringBuffer) const;
void ToString(std::string &stringBuffer) const;
private:
/**
* @brief take the dot product of the two vectors
*/
template <uint8_t vector_size>
float dotProduct(const Matrix<vector_size, 1> &vec1,
const Matrix<vector_size, 1> &vec2);
static float dotProduct(const Matrix<1, vector_size> &vec1,
const Matrix<1, vector_size> &vec2);
/**
* @brief Set all elements in this matrix to zero
@@ -183,17 +183,19 @@ private:
constexpr bool isSquare() { return rows == columns; }
void setMatrixToArray(const std::array<float, rows*columns> & array);
void setMatrixToArray(const std::array<float, rows * columns> &array);
std::array<std::array<float, columns>, rows> matrix;
};
template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::setMatrixToArray(const std::array<float, rows*columns> & array){
void Matrix<rows, columns>::setMatrixToArray(
const std::array<float, rows * columns> &array) {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
uint16_t array_idx =
static_cast<uint16_t>(row_idx) * static_cast<uint16_t>(columns) + static_cast<uint16_t>(column_idx);
static_cast<uint16_t>(row_idx) * static_cast<uint16_t>(columns) +
static_cast<uint16_t>(column_idx);
if (array_idx < array.size()) {
this->matrix[row_idx][column_idx] = array[array_idx];
} else {
@@ -208,7 +210,7 @@ template <uint8_t rows, uint8_t columns> Matrix<rows, columns>::Matrix() {
}
template <uint8_t rows, uint8_t columns>
Matrix<rows, columns>::Matrix(const std::array<float, rows*columns> &array) {
Matrix<rows, columns>::Matrix(const std::array<float, rows * columns> &array) {
this->setMatrixToArray(array);
}
@@ -234,48 +236,49 @@ void Matrix<rows, columns>::Add(const Matrix<rows, columns> &other,
Matrix<rows, columns> &result) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
result[row_idx][column_idx] = this->Get(row_idx, column_idx) + other.Get(row_idx, column_idx);
result[row_idx][column_idx] =
this->Get(row_idx, column_idx) + other.Get(row_idx, column_idx);
}
}
}
template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::Subtract(const Matrix<rows, columns> &other,
Matrix<rows, columns> &result) const {
void Matrix<rows, columns>::Sub(const Matrix<rows, columns> &other,
Matrix<rows, columns> &result) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
result[row_idx][column_idx] = this->Get(row_idx, column_idx) - other.Get(row_idx, column_idx);
result[row_idx][column_idx] =
this->Get(row_idx, column_idx) - other.Get(row_idx, column_idx);
}
}
}
template <uint8_t rows, uint8_t columns>
template <uint8_t other_columns>
void Matrix<rows, columns>::Multiply(
const Matrix<rows, columns> &other,
Matrix<columns, other_columns> &result) const {
void Matrix<rows, columns>::Mult(const Matrix<rows, columns> &other,
Matrix<columns, other_columns> &result) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
// get our row
Matrix<rows, 1> this_row;
Matrix<1, columns> this_row;
this->GetRow(row_idx, this_row);
// get the other matrices column
Matrix<1, columns> other_column;
Matrix<rows, 1> other_column;
other.GetColumn(column_idx, other_column);
// transpose the other matrix's column
Matrix<columns, 1> other_column_t;
Matrix<1, rows> other_column_t;
other_column.Transpose(other_column_t);
// the result's index is equal to the dot product of these two vectors
result[row_idx][column_idx] =
this->dotProduct(this_row, other_column_t);
Matrix<rows, columns>::dotProduct(this_row, other_column_t);
}
}
}
template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::Multiply(float scalar,
Matrix<rows, columns> &result) const {
void Matrix<rows, columns>::Mult(float scalar,
Matrix<rows, columns> &result) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
result[row_idx][column_idx] = this->Get(row_idx, column_idx) * scalar;
@@ -311,7 +314,7 @@ void Matrix<rows, columns>::Invert(Matrix<rows, columns> &result) const {
float determinant = this->Det();
// scale the result by 1/determinant and we have our answer
result.Multiply(1 / determinant);
result.Mult(1 / determinant);
}
template <uint8_t rows, uint8_t columns>
@@ -327,7 +330,7 @@ template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::Square(Matrix<rows, columns> &result) const {
static_assert(this->isSquare(), "You can't square an non-square matrix.");
this->Multiply(this, result);
this->Mult(this, result);
}
template <uint8_t rows, uint8_t columns>
@@ -424,7 +427,7 @@ void Matrix<rows, columns>::ElementDivide(const Matrix<rows, columns> &other,
template <uint8_t rows, uint8_t columns>
float Matrix<rows, columns>::Get(uint8_t row_index,
uint8_t column_index) const {
uint8_t column_index) const {
return this->matrix[row_index][column_index];
}
@@ -438,17 +441,17 @@ template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::GetColumn(uint8_t column_index,
Matrix<rows, 1> &column) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
column[0][column_index] = this->Get(row_idx, column_index);
column[row_idx][0] = this->Get(row_idx, column_index);
}
}
template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::ToString(std::string & stringBuffer) const{
for(uint8_t row_idx{0}; row_idx < rows; row_idx++){
void Matrix<rows, columns>::ToString(std::string &stringBuffer) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
stringBuffer += "|";
for(uint8_t column_idx{0}; column_idx < columns; column_idx++){
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
stringBuffer += std::to_string(this->matrix[row_idx][column_idx]);
if(column_idx != columns - 1){
if (column_idx != columns - 1) {
stringBuffer += "\t";
}
}
@@ -458,11 +461,11 @@ void Matrix<rows, columns>::ToString(std::string & stringBuffer) const{
template <uint8_t rows, uint8_t columns>
template <uint8_t vector_size>
float Matrix<rows, columns>::dotProduct(const Matrix<vector_size, 1> &vec1,
const Matrix<vector_size, 1> &vec2) {
float Matrix<rows, columns>::dotProduct(const Matrix<1, vector_size> &vec1,
const Matrix<1, vector_size> &vec2) {
float sum{0};
for (uint8_t i{0}; i < vector_size; i++) {
sum += vec1.Get(i, 0) * vec2.Get(i, 0);
sum += vec1.Get(0, i) * vec2.Get(0, i);
}
return sum;
@@ -516,8 +519,7 @@ void Matrix<rows, columns>::adjugate(Matrix<rows, columns> &result) const {
for (uint8_t column_iter{0}; column_iter < columns; column_iter++) {
float sign = ((row_iter + 1) % 2) == 0 ? -1 : 1;
sign *= ((column_iter + 1) % 2) == 0 ? -1 : 1;
result[row_iter][column_iter] =
this->Get(row_iter, column_iter) * sign;
result[row_iter][column_iter] = this->Get(row_iter, column_iter) * sign;
}
}
}

View File

@@ -8,21 +8,64 @@
#include <array>
#include <iostream>
TEST_CASE("Matrix Addition", "Matrix::Add") {
TEST_CASE("Elementary Matrix Operations", "Matrix::Add") {
std::array<float, 4> arr1{1, 2, 3, 4};
std::array<float, 4> arr2{5, 6, 7, 8};
Matrix<2, 2> mat1{arr1};
Matrix<2, 2> mat2{arr2};
std::string strBuf1 = "";
mat1.ToString(strBuf1);
std::cout << strBuf1 << std::endl;
Matrix<2, 2> mat3{};
mat1.Add(mat2, mat3);
REQUIRE(mat3.Get(0, 0) == 6);
REQUIRE(mat3.Get(0, 1) == 8);
REQUIRE(mat3.Get(1, 0) == 10);
REQUIRE(mat3.Get(1, 1) == 12);
SECTION("Initialization") {
// array initialization
REQUIRE(mat1.Get(0, 0) == 1);
REQUIRE(mat1.Get(0, 1) == 2);
REQUIRE(mat1.Get(1, 0) == 3);
REQUIRE(mat1.Get(1, 1) == 4);
// empty initialization
REQUIRE(mat3.Get(0, 0) == 0);
REQUIRE(mat3.Get(0, 1) == 0);
REQUIRE(mat3.Get(1, 0) == 0);
REQUIRE(mat3.Get(1, 1) == 0);
}
SECTION("Addition") {
std::string strBuf1 = "";
mat1.ToString(strBuf1);
std::cout << "Matrix 1:\n" << strBuf1 << std::endl;
mat1.Add(mat2, mat3);
REQUIRE(mat3.Get(0, 0) == 6);
REQUIRE(mat3.Get(0, 1) == 8);
REQUIRE(mat3.Get(1, 0) == 10);
REQUIRE(mat3.Get(1, 1) == 12);
}
SECTION("Subtraction") {
mat1.Sub(mat2, mat3);
REQUIRE(mat3.Get(0, 0) == -4);
REQUIRE(mat3.Get(0, 1) == -4);
REQUIRE(mat3.Get(1, 0) == -4);
REQUIRE(mat3.Get(1, 1) == -4);
}
SECTION("Multiplication") {
mat1.Mult(mat2, mat3);
REQUIRE(mat3.Get(0, 0) == 19);
REQUIRE(mat3.Get(0, 1) == 22);
REQUIRE(mat3.Get(1, 0) == 43);
REQUIRE(mat3.Get(1, 1) == 50);
}
SECTION("Scalar Multiplication") {
mat1.Mult(2, mat3);
REQUIRE(mat3.Get(0, 0) == 2);
REQUIRE(mat3.Get(0, 1) == 4);
REQUIRE(mat3.Get(1, 0) == 6);
REQUIRE(mat3.Get(1, 1) == 8);
}
}