From ac8c70d5a413e26693069b60bb5365b21afcd556 Mon Sep 17 00:00:00 2001 From: Quinn Date: Mon, 9 Dec 2024 17:48:50 -0500 Subject: [PATCH] Started working on the matrix library --- Matrix.h | 203 +++++++++++++++++++++++++++++++++++++++++++++++++++++ Vector3D.h | 1 + 2 files changed, 204 insertions(+) create mode 100644 Matrix.h diff --git a/Matrix.h b/Matrix.h new file mode 100644 index 0000000..5ef8848 --- /dev/null +++ b/Matrix.h @@ -0,0 +1,203 @@ +#include +#include +#include + +template +class Matrix{ + public: + /** + * @brief Element-wise matrix addition + * @param other the other matrix to add to this one + * @param result A buffer to store the result into + */ + void Add(const Matrix & other, Matrix & result) const; + + /** + * @brief Element-wise subtract matrix + * @param other the other matrix to subtract from this one + * @param result A buffer to store the result into + */ + void Subtract(const Matrix & other, Matrix & result) const; + + /** + * @brief Matrix multiply the two matrices + * @param other the other matrix to multiply into this one + * @param result A buffer to store the result into + */ + template + void Multiply(const Matrix & other, Matrix & result) const; + + /** + * @brief Multiply the matrix by a scalar + * @param scalar the the scalar to multiply by + * @param result A buffer to store the result into + */ + void Multiply(float scalar, Matrix & result) const; + + /** + * @brief Invert this matrix + * @param result A buffer to store the result into + */ + void Invert(Matrix & result) const; + + /** + * @brief Transpose this matrix + * @param result A buffer to store the result into + */ + void Transpose(Matrix & result) const; + + /** + * @brief Square this matrix + * @param result A buffer to store the result into + */ + void Square(Matrix & result) const; + + /** + * @return Get the determinant of the matrix + */ + float Det(); + + /** + * @brief Element-wise multiply the two matrices + * @param other the other matrix to multiply into this one + * @param result A buffer to store the result into + */ + void ElementMultiply(const Matrix & other, Matrix & result) const; + + /** + * @brief Element-wise divide the two matrices + * @param other the other matrix to multiply into this one + * @param result A buffer to store the result into + */ + void ElementDivide(const Matrix & other, Matrix & result) const; + + /** + * @brief Get an element from the matrix + * @param row the row index of the element + * @param column the column index of the element + * @return The value of the element you want to get + */ + float & Get(uint8_t row_index, uint8_t column_index) const; + + /** + * @brief Get a row from the matrix + * @param row_index the row index to get + * @param row a buffer to write the row into + */ + void GetRow(uint8_t row_index, Matrix & row) const; + + /** + * @brief Get a row from the matrix + * @param column_index the row index to get + * @param column a buffer to write the row into + */ + void GetColumn(uint8_t column_index, Matrix<1, columns> & column) const; + + /** + * @brief Get the number of rows in this matrix + */ + constexpr uint8_t GetRowSize(){return rows;} + + /** + * @brief Get the number of columns in this matrix + */ + constexpr uint8_t GetColumnSize(){return columns;} + + private: + + /** + * @brief take the dot product of the two vectors + */ + template + float dotProduct(const Matrix & vec1, const Matrix & vec2); + + /** + * @brief Set all elements in this matrix to zero + */ + void zeroMatrix(); + + void matrixOfMinors(const Matrix & input, Matrix & result) const; + + void adjugate(const Matrix & input, Matrix & result) const; + + std::array, rows> matrix; +}; + +template +void Matrix::Add(const Matrix & other, Matrix & result) const{ + for(uint8_t row{0}; row < rows; row++){ + for(uint8_t column{0}; column < columns; column++){ + result.Get(row, column) = this->Get(row, column) + other.Get(row, column); + } + } +} + +template +void Matrix::Subtract(const Matrix & other, Matrix & result) const{ + for(uint8_t row{0}; row < rows; row++){ + for(uint8_t column{0}; column < columns; column++){ + result.Get(row, column) = this->Get(row, column) - other.Get(row, column); + } + } +} + +template +template +void Matrix::Multiply(const Matrix & other, Matrix & result) const{ + for(uint8_t row_idx{0}; row_idx < rows; row_idx++){ + for(uint8_t column_idx{0}; column_idx < columns; column_idx++){ + // get our row + Matrix this_row; + this->GetRow(row_idx, this_row); + // get the other matrices column + Matrix<1, columns> other_column; + other.GetColumn(column_idx, other_column); + // transpose the other matrix's column + Matrix other_column_t; + other_column.Transpose(other_column_t); + + // the result's index is equal to the dot product of these two vectors + result.Get(row_idx, column_idx) = this->dotProduct(this_row, other_column_t); + } + } +} + +template +void Matrix::Multiply(float scalar, Matrix & result) const{ + for(uint8_t row_idx{0}; row_idx < rows; row_idx++){ + for(uint8_t column_idx{0}; column_idx < columns; column_idx++){ + result.Get(row_idx, column_idx) = this->Get(row_idx, column_idx) * scalar; + } + } +} + +template +void Matrix::Invert(Matrix & result) const{ + // since all matrix sizes have to be statically specified at compile time we can do this + static_assert(rows == columns, "Your matrix isn't square and can't be inverted"); + + // unfortunately we can't calculate this at compile time so we'll just reurn zeros + if(this->Det() < 0){ + // you can't invert a matrix with a negative determinant + result.zeroMatrix(); + return; + } + + // how to calculate the inverse: https://www.mathsisfun.com/algebra/matrix-inverse-minors-cofactors-adjugate.html + + // calculate the matrix of minors + Matrix minors{}; + this->matrixOfMinors(this, minors); + + // now adjugate the matrix and save it in our output + this->adjugate(minors, result); + float determinant = this->Det(); + + // scale the result by 1/determinant and we have our answer + result.Multiply(1/determinant); +} + +template +void Matrix::Transpose(Matrix & result) const{ + +} \ No newline at end of file diff --git a/Vector3D.h b/Vector3D.h index cb60300..f625f16 100644 --- a/Vector3D.h +++ b/Vector3D.h @@ -2,6 +2,7 @@ #include #include +#include template class V3D{