Got matrix multiplication working
This commit is contained in:
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@@ -70,5 +70,6 @@
|
|||||||
"thread": "cpp",
|
"thread": "cpp",
|
||||||
"typeinfo": "cpp",
|
"typeinfo": "cpp",
|
||||||
"variant": "cpp"
|
"variant": "cpp"
|
||||||
}
|
},
|
||||||
|
"clangd.enable": false
|
||||||
}
|
}
|
||||||
92
Matrix.hpp
92
Matrix.hpp
@@ -13,7 +13,7 @@ public:
|
|||||||
/**
|
/**
|
||||||
* @brief Initialize a matrix with an array
|
* @brief Initialize a matrix with an array
|
||||||
*/
|
*/
|
||||||
Matrix(const std::array<float, rows*columns> &array);
|
Matrix(const std::array<float, rows * columns> &array);
|
||||||
|
|
||||||
// TODO: Figure out how to do this
|
// TODO: Figure out how to do this
|
||||||
/**
|
/**
|
||||||
@@ -37,8 +37,8 @@ public:
|
|||||||
* @param result A buffer to store the result into
|
* @param result A buffer to store the result into
|
||||||
* @note there is no problem if result == this
|
* @note there is no problem if result == this
|
||||||
*/
|
*/
|
||||||
void Subtract(const Matrix<rows, columns> &other,
|
void Sub(const Matrix<rows, columns> &other,
|
||||||
Matrix<rows, columns> &result) const;
|
Matrix<rows, columns> &result) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Matrix multiply the two matrices
|
* @brief Matrix multiply the two matrices
|
||||||
@@ -46,8 +46,8 @@ public:
|
|||||||
* @param result A buffer to store the result into
|
* @param result A buffer to store the result into
|
||||||
*/
|
*/
|
||||||
template <uint8_t other_columns>
|
template <uint8_t other_columns>
|
||||||
void Multiply(const Matrix<rows, columns> &other,
|
void Mult(const Matrix<rows, columns> &other,
|
||||||
Matrix<columns, other_columns> &result) const;
|
Matrix<columns, other_columns> &result) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Multiply the matrix by a scalar
|
* @brief Multiply the matrix by a scalar
|
||||||
@@ -55,7 +55,7 @@ public:
|
|||||||
* @param result A buffer to store the result into
|
* @param result A buffer to store the result into
|
||||||
* @note there is no problem if result == this
|
* @note there is no problem if result == this
|
||||||
*/
|
*/
|
||||||
void Multiply(float scalar, Matrix<rows, columns> &result) const;
|
void Mult(float scalar, Matrix<rows, columns> &result) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Invert this matrix
|
* @brief Invert this matrix
|
||||||
@@ -117,13 +117,13 @@ public:
|
|||||||
* @brief get the specified row of the matrix returned as a reference to the
|
* @brief get the specified row of the matrix returned as a reference to the
|
||||||
* internal array
|
* internal array
|
||||||
*/
|
*/
|
||||||
std::array<float, columns> &operator[](uint8_t row_index){
|
std::array<float, columns> &operator[](uint8_t row_index) {
|
||||||
return this->matrix[row_index];
|
return this->matrix[row_index];
|
||||||
}
|
}
|
||||||
|
|
||||||
Matrix<rows, columns> &operator=(const Matrix<rows, columns> &other){
|
Matrix<rows, columns> &operator=(const Matrix<rows, columns> &other) {
|
||||||
for(uint8_t row_idx{0}; row_idx < rows; row_idx++){
|
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
||||||
for(uint8_t column_idx{0}; column_idx < columns; column_idx++){
|
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
||||||
this->matrix[row_idx][column_idx] = other.Get(row_idx, column_idx);
|
this->matrix[row_idx][column_idx] = other.Get(row_idx, column_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -153,15 +153,15 @@ public:
|
|||||||
*/
|
*/
|
||||||
constexpr uint8_t GetColumnSize() { return columns; }
|
constexpr uint8_t GetColumnSize() { return columns; }
|
||||||
|
|
||||||
void ToString(std::string & stringBuffer) const;
|
void ToString(std::string &stringBuffer) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/**
|
/**
|
||||||
* @brief take the dot product of the two vectors
|
* @brief take the dot product of the two vectors
|
||||||
*/
|
*/
|
||||||
template <uint8_t vector_size>
|
template <uint8_t vector_size>
|
||||||
float dotProduct(const Matrix<vector_size, 1> &vec1,
|
static float dotProduct(const Matrix<1, vector_size> &vec1,
|
||||||
const Matrix<vector_size, 1> &vec2);
|
const Matrix<1, vector_size> &vec2);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Set all elements in this matrix to zero
|
* @brief Set all elements in this matrix to zero
|
||||||
@@ -183,17 +183,19 @@ private:
|
|||||||
|
|
||||||
constexpr bool isSquare() { return rows == columns; }
|
constexpr bool isSquare() { return rows == columns; }
|
||||||
|
|
||||||
void setMatrixToArray(const std::array<float, rows*columns> & array);
|
void setMatrixToArray(const std::array<float, rows * columns> &array);
|
||||||
|
|
||||||
std::array<std::array<float, columns>, rows> matrix;
|
std::array<std::array<float, columns>, rows> matrix;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <uint8_t rows, uint8_t columns>
|
template <uint8_t rows, uint8_t columns>
|
||||||
void Matrix<rows, columns>::setMatrixToArray(const std::array<float, rows*columns> & array){
|
void Matrix<rows, columns>::setMatrixToArray(
|
||||||
|
const std::array<float, rows * columns> &array) {
|
||||||
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
||||||
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
||||||
uint16_t array_idx =
|
uint16_t array_idx =
|
||||||
static_cast<uint16_t>(row_idx) * static_cast<uint16_t>(columns) + static_cast<uint16_t>(column_idx);
|
static_cast<uint16_t>(row_idx) * static_cast<uint16_t>(columns) +
|
||||||
|
static_cast<uint16_t>(column_idx);
|
||||||
if (array_idx < array.size()) {
|
if (array_idx < array.size()) {
|
||||||
this->matrix[row_idx][column_idx] = array[array_idx];
|
this->matrix[row_idx][column_idx] = array[array_idx];
|
||||||
} else {
|
} else {
|
||||||
@@ -208,14 +210,14 @@ template <uint8_t rows, uint8_t columns> Matrix<rows, columns>::Matrix() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <uint8_t rows, uint8_t columns>
|
template <uint8_t rows, uint8_t columns>
|
||||||
Matrix<rows, columns>::Matrix(const std::array<float, rows*columns> &array) {
|
Matrix<rows, columns>::Matrix(const std::array<float, rows * columns> &array) {
|
||||||
this->setMatrixToArray(array);
|
this->setMatrixToArray(array);
|
||||||
}
|
}
|
||||||
|
|
||||||
// template <uint8_t rows, uint8_t columns>
|
// template <uint8_t rows, uint8_t columns>
|
||||||
// template <typename... Args>
|
// template <typename... Args>
|
||||||
// Matrix<rows, columns>::Matrix(Args&&... args){
|
// Matrix<rows, columns>::Matrix(Args&&... args){
|
||||||
|
|
||||||
// // Initialize a std::array with the arguments
|
// // Initialize a std::array with the arguments
|
||||||
// if(typeid(args) == typeid(std::array<float, 4>)){
|
// if(typeid(args) == typeid(std::array<float, 4>)){
|
||||||
// this->setMatrixToArray(args);
|
// this->setMatrixToArray(args);
|
||||||
@@ -226,7 +228,7 @@ Matrix<rows, columns>::Matrix(const std::array<float, rows*columns> &array) {
|
|||||||
// // now store the array in our internal matrix
|
// // now store the array in our internal matrix
|
||||||
// this->setMatrixToArray(values);
|
// this->setMatrixToArray(values);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// }
|
// }
|
||||||
|
|
||||||
template <uint8_t rows, uint8_t columns>
|
template <uint8_t rows, uint8_t columns>
|
||||||
@@ -234,48 +236,49 @@ void Matrix<rows, columns>::Add(const Matrix<rows, columns> &other,
|
|||||||
Matrix<rows, columns> &result) const {
|
Matrix<rows, columns> &result) const {
|
||||||
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
||||||
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
||||||
result[row_idx][column_idx] = this->Get(row_idx, column_idx) + other.Get(row_idx, column_idx);
|
result[row_idx][column_idx] =
|
||||||
|
this->Get(row_idx, column_idx) + other.Get(row_idx, column_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <uint8_t rows, uint8_t columns>
|
template <uint8_t rows, uint8_t columns>
|
||||||
void Matrix<rows, columns>::Subtract(const Matrix<rows, columns> &other,
|
void Matrix<rows, columns>::Sub(const Matrix<rows, columns> &other,
|
||||||
Matrix<rows, columns> &result) const {
|
Matrix<rows, columns> &result) const {
|
||||||
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
||||||
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
||||||
result[row_idx][column_idx] = this->Get(row_idx, column_idx) - other.Get(row_idx, column_idx);
|
result[row_idx][column_idx] =
|
||||||
|
this->Get(row_idx, column_idx) - other.Get(row_idx, column_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <uint8_t rows, uint8_t columns>
|
template <uint8_t rows, uint8_t columns>
|
||||||
template <uint8_t other_columns>
|
template <uint8_t other_columns>
|
||||||
void Matrix<rows, columns>::Multiply(
|
void Matrix<rows, columns>::Mult(const Matrix<rows, columns> &other,
|
||||||
const Matrix<rows, columns> &other,
|
Matrix<columns, other_columns> &result) const {
|
||||||
Matrix<columns, other_columns> &result) const {
|
|
||||||
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
||||||
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
||||||
// get our row
|
// get our row
|
||||||
Matrix<rows, 1> this_row;
|
Matrix<1, columns> this_row;
|
||||||
this->GetRow(row_idx, this_row);
|
this->GetRow(row_idx, this_row);
|
||||||
// get the other matrices column
|
// get the other matrices column
|
||||||
Matrix<1, columns> other_column;
|
Matrix<rows, 1> other_column;
|
||||||
other.GetColumn(column_idx, other_column);
|
other.GetColumn(column_idx, other_column);
|
||||||
// transpose the other matrix's column
|
// transpose the other matrix's column
|
||||||
Matrix<columns, 1> other_column_t;
|
Matrix<1, rows> other_column_t;
|
||||||
other_column.Transpose(other_column_t);
|
other_column.Transpose(other_column_t);
|
||||||
|
|
||||||
// the result's index is equal to the dot product of these two vectors
|
// the result's index is equal to the dot product of these two vectors
|
||||||
result[row_idx][column_idx] =
|
result[row_idx][column_idx] =
|
||||||
this->dotProduct(this_row, other_column_t);
|
Matrix<rows, columns>::dotProduct(this_row, other_column_t);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <uint8_t rows, uint8_t columns>
|
template <uint8_t rows, uint8_t columns>
|
||||||
void Matrix<rows, columns>::Multiply(float scalar,
|
void Matrix<rows, columns>::Mult(float scalar,
|
||||||
Matrix<rows, columns> &result) const {
|
Matrix<rows, columns> &result) const {
|
||||||
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
||||||
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
||||||
result[row_idx][column_idx] = this->Get(row_idx, column_idx) * scalar;
|
result[row_idx][column_idx] = this->Get(row_idx, column_idx) * scalar;
|
||||||
@@ -311,7 +314,7 @@ void Matrix<rows, columns>::Invert(Matrix<rows, columns> &result) const {
|
|||||||
float determinant = this->Det();
|
float determinant = this->Det();
|
||||||
|
|
||||||
// scale the result by 1/determinant and we have our answer
|
// scale the result by 1/determinant and we have our answer
|
||||||
result.Multiply(1 / determinant);
|
result.Mult(1 / determinant);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <uint8_t rows, uint8_t columns>
|
template <uint8_t rows, uint8_t columns>
|
||||||
@@ -327,7 +330,7 @@ template <uint8_t rows, uint8_t columns>
|
|||||||
void Matrix<rows, columns>::Square(Matrix<rows, columns> &result) const {
|
void Matrix<rows, columns>::Square(Matrix<rows, columns> &result) const {
|
||||||
static_assert(this->isSquare(), "You can't square an non-square matrix.");
|
static_assert(this->isSquare(), "You can't square an non-square matrix.");
|
||||||
|
|
||||||
this->Multiply(this, result);
|
this->Mult(this, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <uint8_t rows, uint8_t columns>
|
template <uint8_t rows, uint8_t columns>
|
||||||
@@ -424,7 +427,7 @@ void Matrix<rows, columns>::ElementDivide(const Matrix<rows, columns> &other,
|
|||||||
|
|
||||||
template <uint8_t rows, uint8_t columns>
|
template <uint8_t rows, uint8_t columns>
|
||||||
float Matrix<rows, columns>::Get(uint8_t row_index,
|
float Matrix<rows, columns>::Get(uint8_t row_index,
|
||||||
uint8_t column_index) const {
|
uint8_t column_index) const {
|
||||||
return this->matrix[row_index][column_index];
|
return this->matrix[row_index][column_index];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -438,17 +441,17 @@ template <uint8_t rows, uint8_t columns>
|
|||||||
void Matrix<rows, columns>::GetColumn(uint8_t column_index,
|
void Matrix<rows, columns>::GetColumn(uint8_t column_index,
|
||||||
Matrix<rows, 1> &column) const {
|
Matrix<rows, 1> &column) const {
|
||||||
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
||||||
column[0][column_index] = this->Get(row_idx, column_index);
|
column[row_idx][0] = this->Get(row_idx, column_index);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <uint8_t rows, uint8_t columns>
|
template <uint8_t rows, uint8_t columns>
|
||||||
void Matrix<rows, columns>::ToString(std::string & stringBuffer) const{
|
void Matrix<rows, columns>::ToString(std::string &stringBuffer) const {
|
||||||
for(uint8_t row_idx{0}; row_idx < rows; row_idx++){
|
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
|
||||||
stringBuffer += "|";
|
stringBuffer += "|";
|
||||||
for(uint8_t column_idx{0}; column_idx < columns; column_idx++){
|
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
|
||||||
stringBuffer += std::to_string(this->matrix[row_idx][column_idx]);
|
stringBuffer += std::to_string(this->matrix[row_idx][column_idx]);
|
||||||
if(column_idx != columns - 1){
|
if (column_idx != columns - 1) {
|
||||||
stringBuffer += "\t";
|
stringBuffer += "\t";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -458,11 +461,11 @@ void Matrix<rows, columns>::ToString(std::string & stringBuffer) const{
|
|||||||
|
|
||||||
template <uint8_t rows, uint8_t columns>
|
template <uint8_t rows, uint8_t columns>
|
||||||
template <uint8_t vector_size>
|
template <uint8_t vector_size>
|
||||||
float Matrix<rows, columns>::dotProduct(const Matrix<vector_size, 1> &vec1,
|
float Matrix<rows, columns>::dotProduct(const Matrix<1, vector_size> &vec1,
|
||||||
const Matrix<vector_size, 1> &vec2) {
|
const Matrix<1, vector_size> &vec2) {
|
||||||
float sum{0};
|
float sum{0};
|
||||||
for (uint8_t i{0}; i < vector_size; i++) {
|
for (uint8_t i{0}; i < vector_size; i++) {
|
||||||
sum += vec1.Get(i, 0) * vec2.Get(i, 0);
|
sum += vec1.Get(0, i) * vec2.Get(0, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
return sum;
|
return sum;
|
||||||
@@ -516,8 +519,7 @@ void Matrix<rows, columns>::adjugate(Matrix<rows, columns> &result) const {
|
|||||||
for (uint8_t column_iter{0}; column_iter < columns; column_iter++) {
|
for (uint8_t column_iter{0}; column_iter < columns; column_iter++) {
|
||||||
float sign = ((row_iter + 1) % 2) == 0 ? -1 : 1;
|
float sign = ((row_iter + 1) % 2) == 0 ? -1 : 1;
|
||||||
sign *= ((column_iter + 1) % 2) == 0 ? -1 : 1;
|
sign *= ((column_iter + 1) % 2) == 0 ? -1 : 1;
|
||||||
result[row_iter][column_iter] =
|
result[row_iter][column_iter] = this->Get(row_iter, column_iter) * sign;
|
||||||
this->Get(row_iter, column_iter) * sign;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,21 +8,64 @@
|
|||||||
#include <array>
|
#include <array>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
TEST_CASE("Matrix Addition", "Matrix::Add") {
|
TEST_CASE("Elementary Matrix Operations", "Matrix::Add") {
|
||||||
std::array<float, 4> arr1{1, 2, 3, 4};
|
std::array<float, 4> arr1{1, 2, 3, 4};
|
||||||
std::array<float, 4> arr2{5, 6, 7, 8};
|
std::array<float, 4> arr2{5, 6, 7, 8};
|
||||||
Matrix<2, 2> mat1{arr1};
|
Matrix<2, 2> mat1{arr1};
|
||||||
Matrix<2, 2> mat2{arr2};
|
Matrix<2, 2> mat2{arr2};
|
||||||
|
|
||||||
std::string strBuf1 = "";
|
|
||||||
mat1.ToString(strBuf1);
|
|
||||||
std::cout << strBuf1 << std::endl;
|
|
||||||
|
|
||||||
Matrix<2, 2> mat3{};
|
Matrix<2, 2> mat3{};
|
||||||
mat1.Add(mat2, mat3);
|
|
||||||
|
|
||||||
REQUIRE(mat3.Get(0, 0) == 6);
|
SECTION("Initialization") {
|
||||||
REQUIRE(mat3.Get(0, 1) == 8);
|
// array initialization
|
||||||
REQUIRE(mat3.Get(1, 0) == 10);
|
REQUIRE(mat1.Get(0, 0) == 1);
|
||||||
REQUIRE(mat3.Get(1, 1) == 12);
|
REQUIRE(mat1.Get(0, 1) == 2);
|
||||||
|
REQUIRE(mat1.Get(1, 0) == 3);
|
||||||
|
REQUIRE(mat1.Get(1, 1) == 4);
|
||||||
|
|
||||||
|
// empty initialization
|
||||||
|
REQUIRE(mat3.Get(0, 0) == 0);
|
||||||
|
REQUIRE(mat3.Get(0, 1) == 0);
|
||||||
|
REQUIRE(mat3.Get(1, 0) == 0);
|
||||||
|
REQUIRE(mat3.Get(1, 1) == 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("Addition") {
|
||||||
|
std::string strBuf1 = "";
|
||||||
|
mat1.ToString(strBuf1);
|
||||||
|
std::cout << "Matrix 1:\n" << strBuf1 << std::endl;
|
||||||
|
|
||||||
|
mat1.Add(mat2, mat3);
|
||||||
|
|
||||||
|
REQUIRE(mat3.Get(0, 0) == 6);
|
||||||
|
REQUIRE(mat3.Get(0, 1) == 8);
|
||||||
|
REQUIRE(mat3.Get(1, 0) == 10);
|
||||||
|
REQUIRE(mat3.Get(1, 1) == 12);
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("Subtraction") {
|
||||||
|
mat1.Sub(mat2, mat3);
|
||||||
|
|
||||||
|
REQUIRE(mat3.Get(0, 0) == -4);
|
||||||
|
REQUIRE(mat3.Get(0, 1) == -4);
|
||||||
|
REQUIRE(mat3.Get(1, 0) == -4);
|
||||||
|
REQUIRE(mat3.Get(1, 1) == -4);
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("Multiplication") {
|
||||||
|
mat1.Mult(mat2, mat3);
|
||||||
|
|
||||||
|
REQUIRE(mat3.Get(0, 0) == 19);
|
||||||
|
REQUIRE(mat3.Get(0, 1) == 22);
|
||||||
|
REQUIRE(mat3.Get(1, 0) == 43);
|
||||||
|
REQUIRE(mat3.Get(1, 1) == 50);
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("Scalar Multiplication") {
|
||||||
|
mat1.Mult(2, mat3);
|
||||||
|
|
||||||
|
REQUIRE(mat3.Get(0, 0) == 2);
|
||||||
|
REQUIRE(mat3.Get(0, 1) == 4);
|
||||||
|
REQUIRE(mat3.Get(1, 0) == 6);
|
||||||
|
REQUIRE(mat3.Get(1, 1) == 8);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user