Got matrix multiplication working
This commit is contained in:
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@@ -70,5 +70,6 @@
|
||||
"thread": "cpp",
|
||||
"typeinfo": "cpp",
|
||||
"variant": "cpp"
|
||||
}
|
||||
},
|
||||
"clangd.enable": false
|
||||
}
|
||||
52
Matrix.hpp
52
Matrix.hpp
@@ -37,7 +37,7 @@ public:
|
||||
* @param result A buffer to store the result into
|
||||
* @note there is no problem if result == this
|
||||
*/
|
||||
void Subtract(const Matrix<rows, columns> &other,
|
||||
void Sub(const Matrix<rows, columns> &other,
|
||||
Matrix<rows, columns> &result) const;
|
||||
|
||||
/**
|
||||
@@ -46,7 +46,7 @@ public:
|
||||
* @param result A buffer to store the result into
|
||||
*/
|
||||
template <uint8_t other_columns>
|
||||
void Multiply(const Matrix<rows, columns> &other,
|
||||
void Mult(const Matrix<rows, columns> &other,
|
||||
Matrix<columns, other_columns> &result) const;
|
||||
|
||||
/**
|
||||
@@ -55,7 +55,7 @@ public:
|
||||
* @param result A buffer to store the result into
|
||||
* @note there is no problem if result == this
|
||||
*/
|
||||
void Multiply(float scalar, Matrix<rows, columns> &result) const;
|
||||
void Mult(float scalar, Matrix<rows, columns> &result) const;
|
||||
|
||||
/**
|
||||
* @brief Invert this matrix
|
||||
@@ -160,8 +160,8 @@ private:
|
||||
* @brief take the dot product of the two vectors
|
||||
*/
|
||||
template <uint8_t vector_size>
|
||||
float dotProduct(const Matrix<vector_size, 1> &vec1,
|
||||
const Matrix<vector_size, 1> &vec2);
|
||||
static float dotProduct(const Matrix<1, vector_size> &vec1,
|
||||
const Matrix<1, vector_size> &vec2);
|
||||
|
||||
/**
|
||||
* @brief Set all elements in this matrix to zero
|
||||
@@ -189,11 +189,13 @@ private:
|
||||
};
|
||||
|
||||
template <uint8_t rows, uint8_t columns>
|
||||
void Matrix<rows, columns>::setMatrixToArray(const std::array<float, rows*columns> & array){
|
||||
void Matrix<rows, columns>::setMatrixToArray(
|
||||
const std::array<float, rows * columns> &array) {
|
||||
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
||||
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
||||
uint16_t array_idx =
|
||||
static_cast<uint16_t>(row_idx) * static_cast<uint16_t>(columns) + static_cast<uint16_t>(column_idx);
|
||||
static_cast<uint16_t>(row_idx) * static_cast<uint16_t>(columns) +
|
||||
static_cast<uint16_t>(column_idx);
|
||||
if (array_idx < array.size()) {
|
||||
this->matrix[row_idx][column_idx] = array[array_idx];
|
||||
} else {
|
||||
@@ -234,47 +236,48 @@ void Matrix<rows, columns>::Add(const Matrix<rows, columns> &other,
|
||||
Matrix<rows, columns> &result) const {
|
||||
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
||||
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
||||
result[row_idx][column_idx] = this->Get(row_idx, column_idx) + other.Get(row_idx, column_idx);
|
||||
result[row_idx][column_idx] =
|
||||
this->Get(row_idx, column_idx) + other.Get(row_idx, column_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <uint8_t rows, uint8_t columns>
|
||||
void Matrix<rows, columns>::Subtract(const Matrix<rows, columns> &other,
|
||||
void Matrix<rows, columns>::Sub(const Matrix<rows, columns> &other,
|
||||
Matrix<rows, columns> &result) const {
|
||||
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
||||
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
||||
result[row_idx][column_idx] = this->Get(row_idx, column_idx) - other.Get(row_idx, column_idx);
|
||||
result[row_idx][column_idx] =
|
||||
this->Get(row_idx, column_idx) - other.Get(row_idx, column_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <uint8_t rows, uint8_t columns>
|
||||
template <uint8_t other_columns>
|
||||
void Matrix<rows, columns>::Multiply(
|
||||
const Matrix<rows, columns> &other,
|
||||
void Matrix<rows, columns>::Mult(const Matrix<rows, columns> &other,
|
||||
Matrix<columns, other_columns> &result) const {
|
||||
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
||||
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
||||
// get our row
|
||||
Matrix<rows, 1> this_row;
|
||||
Matrix<1, columns> this_row;
|
||||
this->GetRow(row_idx, this_row);
|
||||
// get the other matrices column
|
||||
Matrix<1, columns> other_column;
|
||||
Matrix<rows, 1> other_column;
|
||||
other.GetColumn(column_idx, other_column);
|
||||
// transpose the other matrix's column
|
||||
Matrix<columns, 1> other_column_t;
|
||||
Matrix<1, rows> other_column_t;
|
||||
other_column.Transpose(other_column_t);
|
||||
|
||||
// the result's index is equal to the dot product of these two vectors
|
||||
result[row_idx][column_idx] =
|
||||
this->dotProduct(this_row, other_column_t);
|
||||
Matrix<rows, columns>::dotProduct(this_row, other_column_t);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <uint8_t rows, uint8_t columns>
|
||||
void Matrix<rows, columns>::Multiply(float scalar,
|
||||
void Matrix<rows, columns>::Mult(float scalar,
|
||||
Matrix<rows, columns> &result) const {
|
||||
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
||||
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
||||
@@ -311,7 +314,7 @@ void Matrix<rows, columns>::Invert(Matrix<rows, columns> &result) const {
|
||||
float determinant = this->Det();
|
||||
|
||||
// scale the result by 1/determinant and we have our answer
|
||||
result.Multiply(1 / determinant);
|
||||
result.Mult(1 / determinant);
|
||||
}
|
||||
|
||||
template <uint8_t rows, uint8_t columns>
|
||||
@@ -327,7 +330,7 @@ template <uint8_t rows, uint8_t columns>
|
||||
void Matrix<rows, columns>::Square(Matrix<rows, columns> &result) const {
|
||||
static_assert(this->isSquare(), "You can't square an non-square matrix.");
|
||||
|
||||
this->Multiply(this, result);
|
||||
this->Mult(this, result);
|
||||
}
|
||||
|
||||
template <uint8_t rows, uint8_t columns>
|
||||
@@ -438,7 +441,7 @@ template <uint8_t rows, uint8_t columns>
|
||||
void Matrix<rows, columns>::GetColumn(uint8_t column_index,
|
||||
Matrix<rows, 1> &column) const {
|
||||
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
||||
column[0][column_index] = this->Get(row_idx, column_index);
|
||||
column[row_idx][0] = this->Get(row_idx, column_index);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -458,11 +461,11 @@ void Matrix<rows, columns>::ToString(std::string & stringBuffer) const{
|
||||
|
||||
template <uint8_t rows, uint8_t columns>
|
||||
template <uint8_t vector_size>
|
||||
float Matrix<rows, columns>::dotProduct(const Matrix<vector_size, 1> &vec1,
|
||||
const Matrix<vector_size, 1> &vec2) {
|
||||
float Matrix<rows, columns>::dotProduct(const Matrix<1, vector_size> &vec1,
|
||||
const Matrix<1, vector_size> &vec2) {
|
||||
float sum{0};
|
||||
for (uint8_t i{0}; i < vector_size; i++) {
|
||||
sum += vec1.Get(i, 0) * vec2.Get(i, 0);
|
||||
sum += vec1.Get(0, i) * vec2.Get(0, i);
|
||||
}
|
||||
|
||||
return sum;
|
||||
@@ -516,8 +519,7 @@ void Matrix<rows, columns>::adjugate(Matrix<rows, columns> &result) const {
|
||||
for (uint8_t column_iter{0}; column_iter < columns; column_iter++) {
|
||||
float sign = ((row_iter + 1) % 2) == 0 ? -1 : 1;
|
||||
sign *= ((column_iter + 1) % 2) == 0 ? -1 : 1;
|
||||
result[row_iter][column_iter] =
|
||||
this->Get(row_iter, column_iter) * sign;
|
||||
result[row_iter][column_iter] = this->Get(row_iter, column_iter) * sign;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,17 +8,32 @@
|
||||
#include <array>
|
||||
#include <iostream>
|
||||
|
||||
TEST_CASE("Matrix Addition", "Matrix::Add") {
|
||||
TEST_CASE("Elementary Matrix Operations", "Matrix::Add") {
|
||||
std::array<float, 4> arr1{1, 2, 3, 4};
|
||||
std::array<float, 4> arr2{5, 6, 7, 8};
|
||||
Matrix<2, 2> mat1{arr1};
|
||||
Matrix<2, 2> mat2{arr2};
|
||||
Matrix<2, 2> mat3{};
|
||||
|
||||
SECTION("Initialization") {
|
||||
// array initialization
|
||||
REQUIRE(mat1.Get(0, 0) == 1);
|
||||
REQUIRE(mat1.Get(0, 1) == 2);
|
||||
REQUIRE(mat1.Get(1, 0) == 3);
|
||||
REQUIRE(mat1.Get(1, 1) == 4);
|
||||
|
||||
// empty initialization
|
||||
REQUIRE(mat3.Get(0, 0) == 0);
|
||||
REQUIRE(mat3.Get(0, 1) == 0);
|
||||
REQUIRE(mat3.Get(1, 0) == 0);
|
||||
REQUIRE(mat3.Get(1, 1) == 0);
|
||||
}
|
||||
|
||||
SECTION("Addition") {
|
||||
std::string strBuf1 = "";
|
||||
mat1.ToString(strBuf1);
|
||||
std::cout << strBuf1 << std::endl;
|
||||
std::cout << "Matrix 1:\n" << strBuf1 << std::endl;
|
||||
|
||||
Matrix<2, 2> mat3{};
|
||||
mat1.Add(mat2, mat3);
|
||||
|
||||
REQUIRE(mat3.Get(0, 0) == 6);
|
||||
@@ -26,3 +41,31 @@ TEST_CASE("Matrix Addition", "Matrix::Add") {
|
||||
REQUIRE(mat3.Get(1, 0) == 10);
|
||||
REQUIRE(mat3.Get(1, 1) == 12);
|
||||
}
|
||||
|
||||
SECTION("Subtraction") {
|
||||
mat1.Sub(mat2, mat3);
|
||||
|
||||
REQUIRE(mat3.Get(0, 0) == -4);
|
||||
REQUIRE(mat3.Get(0, 1) == -4);
|
||||
REQUIRE(mat3.Get(1, 0) == -4);
|
||||
REQUIRE(mat3.Get(1, 1) == -4);
|
||||
}
|
||||
|
||||
SECTION("Multiplication") {
|
||||
mat1.Mult(mat2, mat3);
|
||||
|
||||
REQUIRE(mat3.Get(0, 0) == 19);
|
||||
REQUIRE(mat3.Get(0, 1) == 22);
|
||||
REQUIRE(mat3.Get(1, 0) == 43);
|
||||
REQUIRE(mat3.Get(1, 1) == 50);
|
||||
}
|
||||
|
||||
SECTION("Scalar Multiplication") {
|
||||
mat1.Mult(2, mat3);
|
||||
|
||||
REQUIRE(mat3.Get(0, 0) == 2);
|
||||
REQUIRE(mat3.Get(0, 1) == 4);
|
||||
REQUIRE(mat3.Get(1, 0) == 6);
|
||||
REQUIRE(mat3.Get(1, 1) == 8);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user