Reworked some of the matrix interface

This commit is contained in:
2025-02-08 18:06:43 -05:00
parent 713809a82b
commit 2385446ac5
4 changed files with 125 additions and 21 deletions

View File

@@ -34,6 +34,16 @@ Matrix<rows, columns>::Matrix(Args... args)
memcpy(this->matrix.begin(), initList.begin(), minSize * sizeof(float));
}
template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::Identity()
{
this->Fill(0);
for (uint8_t idx{0}; idx < rows; idx++)
{
this->matrix[idx * columns + idx] = 1;
}
}
template <uint8_t rows, uint8_t columns>
Matrix<rows, columns>::Matrix(const Matrix<rows, columns> &other)
{
@@ -112,7 +122,6 @@ Matrix<rows, columns>::Mult(const Matrix<columns, other_columns> &other,
// allocate some buffers for all of our dot products
Matrix<1, columns> this_row;
Matrix<columns, 1> other_column;
Matrix<1, columns> other_column_t;
for (uint8_t row_idx{0}; row_idx < rows; row_idx++)
{
@@ -122,12 +131,10 @@ Matrix<rows, columns>::Mult(const Matrix<columns, other_columns> &other,
{
// get the other matrix'ss column
other.GetColumn(column_idx, other_column);
// transpose the other matrix's column
other_column.Transpose(other_column_t);
// the result's index is equal to the dot product of these two vectors
result[row_idx][column_idx] =
Matrix<rows, columns>::dotProduct(this_row, other_column_t);
Matrix<rows, columns>::dotProduct(this_row, other_column.Transpose());
}
}
@@ -150,14 +157,15 @@ Matrix<rows, columns>::Mult(float scalar, Matrix<rows, columns> &result) const
}
template <uint8_t rows, uint8_t columns>
Matrix<rows, columns> &
Matrix<rows, columns>::Invert(Matrix<rows, columns> &result) const
Matrix<rows, columns>
Matrix<rows, columns>::Invert() const
{
// since all matrix sizes have to be statically specified at compile time we
// can do this
static_assert(rows == columns,
"Your matrix isn't square and can't be inverted");
Matrix<rows, columns> result{};
// unfortunately we can't calculate this at compile time so we'll just reurn
// zeros
float determinant{this->Det()};
@@ -187,9 +195,10 @@ Matrix<rows, columns>::Invert(Matrix<rows, columns> &result) const
}
template <uint8_t rows, uint8_t columns>
Matrix<columns, rows> &
Matrix<rows, columns>::Transpose(Matrix<columns, rows> &result) const
Matrix<columns, rows>
Matrix<rows, columns>::Transpose() const
{
Matrix<columns, rows> result{};
for (uint8_t column_idx{0}; column_idx < rows; column_idx++)
{
for (uint8_t row_idx{0}; row_idx < columns; row_idx++)
@@ -340,14 +349,9 @@ template <uint8_t rows, uint8_t columns>
Matrix<rows, columns> &Matrix<rows, columns>::
operator=(const Matrix<rows, columns> &other)
{
for (uint8_t row_idx{0}; row_idx < rows; row_idx++)
{
for (uint8_t column_idx{0}; column_idx < columns; column_idx++)
{
this->matrix[row_idx * columns + column_idx] =
other.Get(row_idx, column_idx);
}
}
memcpy(this->matrix.begin(), other.matrix.begin(),
rows * columns * sizeof(float));
// return a reference to ourselves so you can chain together these functions
return *this;
}
@@ -525,4 +529,43 @@ Matrix<rows, columns>::Normalize(Matrix<rows, columns> &result) const
return result;
}
template <uint8_t rows, uint8_t columns>
template <uint8_t sub_rows, uint8_t sub_columns>
Matrix<sub_rows, sub_columns> &Matrix<rows, columns>::SubMatrix(Matrix<sub_rows, sub_columns> &buffer, uint8_t row_offset, uint8_t column_offset) const
{
for (uint8_t row_idx{0}; row_idx < sub_rows; row_idx++)
{
for (uint8_t column_idx{0}; column_idx < sub_columns; column_idx++)
{
buffer[row_idx][column_idx] =
this->Get(row_idx + row_offset, column_idx + column_offset);
}
}
return buffer;
}
template <uint8_t rows, uint8_t columns>
template <uint8_t sub_rows, uint8_t sub_columns>
void Matrix<rows, columns>::SetSubMatrix(const Matrix<sub_rows, sub_columns> &sub_matrix, uint8_t row_offset, uint8_t column_offset)
{
uint8_t corrected_sub_rows = sub_rows;
uint8_t corrected_sub_columns = sub_columns;
if (sub_rows + row_offset > rows)
{
corrected_sub_rows = rows - row_offset;
}
if (sub_columns + column_offset > columns)
{
corrected_sub_columns = columns - column_offset;
}
for (uint8_t row_idx{0}; row_idx < corrected_sub_rows; row_idx++)
{
for (uint8_t column_idx{0}; column_idx < corrected_sub_columns; column_idx++)
{
this->matrix[(row_idx + row_offset) * columns + column_idx + column_offset] = sub_matrix.Get(row_idx, column_idx);
}
}
}
#endif // MATRIX_H_

View File

@@ -39,6 +39,12 @@ public:
*/
template <typename... Args>
Matrix(Args... args);
/**
* @brief set the matrix diagonals to 1 and all other values to 0
*/
void Identity();
/**
* @brief Set all elements in this to value
*/
@@ -115,13 +121,13 @@ public:
* @param result A buffer to store the result into
* @warning this is super slow! Only call it if you absolutely have to!!!
*/
Matrix<rows, columns> &Invert(Matrix<rows, columns> &result) const;
Matrix<rows, columns> Invert() const;
/**
* @brief Transpose this matrix
* @param result A buffer to store the result into
*/
Matrix<columns, rows> &Transpose(Matrix<columns, rows> &result) const;
Matrix<columns, rows> Transpose() const;
/**
* @brief reduce the matrix so the sum of its elements equal 1
@@ -187,6 +193,12 @@ public:
Matrix<rows, columns> operator*(float scalar) const;
template <uint8_t sub_rows, uint8_t sub_columns>
Matrix<sub_rows, sub_columns> &SubMatrix(Matrix<sub_rows, sub_columns> &buffer, uint8_t row_offset, uint8_t column_offset) const;
template <uint8_t sub_rows, uint8_t sub_columns>
void SetSubMatrix(const Matrix<sub_rows, sub_columns> &sub_matrix, uint8_t row_offset, uint8_t column_offset);
protected:
std::array<float, rows * columns> matrix;
@@ -202,6 +214,9 @@ private:
static float dotProduct(const Matrix<vector_size, 1> &vec1,
const Matrix<vector_size, 1> &vec2);
static float dotProduct(const Matrix<1, 1> &vec1,
const Matrix<1, 1> &vec2) { return vec1.Get(0, 0) * vec2.Get(0, 0); }
Matrix<rows, columns> &adjugate(Matrix<rows, columns> &result) const;
void setMatrixToArray(const std::array<float, rows * columns> &array);

View File

@@ -30,13 +30,30 @@ float Quaternion::operator[](uint8_t index) const
return 1e+6;
}
void Quaternion::operator=(const Quaternion &other) const
{
static_cast<Matrix<1, 4>>(this->matrix) = static_cast<Matrix<1, 4>>(other.matrix);
}
Quaternion Quaternion::operator*(const Quaternion &other) const
{
Quaternion result{};
this->Q_Mult(other, result);
return result;
}
Quaternion Quaternion::operator*(float scalar) const
{
return Quaternion{this->w * scalar, this->v1 * scalar, this->v2 * scalar, this->v3 * scalar};
}
Quaternion Quaternion::operator+(const Quaternion &other) const
{
return Quaternion{this->w * other.w, this->v1 * other.v1, this->v2 * other.v2, this->v3 * other.v3};
}
Quaternion &
Quaternion::Q_Mult(Quaternion &other, Quaternion &buffer) const
Quaternion::Q_Mult(const Quaternion &other, Quaternion &buffer) const
{
// eq. 6
@@ -68,4 +85,16 @@ void Quaternion::Normalize()
this->v2 /= magnitude;
this->v3 /= magnitude;
this->w /= magnitude;
}
}
Matrix<3, 3> Quaternion::ToRotationMatrix() const
{
float xx = this->v1 * this->v1;
float yy = this->v2 * this->v2;
float zz = this->v3 * this->v3;
Matrix<3, 3> rotationMatrix{
1 - 2 * (yy - zz), 2 * (this->v1 * this->v2 - this->v3 * this->w), 2 * (this->v1 * this->v3 + this->v2 * this->w),
2 * (this->v1 * this->v2 + this->v3 * this->w), 1 - 2 * (xx - zz), 2 * (this->v2 * this->v3 - this->v1 * this->w),
2 * (this->v1 * this->v3 - this->v2 * this->w), 2 * (this->v2 * this->v3 + this->v1 * this->w), 1 - 2 * (xx - yy)};
return rotationMatrix;
};

View File

@@ -27,6 +27,21 @@ public:
*/
float operator[](uint8_t index) const;
/**
* @brief Assign one quaternion to another
*/
void operator=(const Quaternion &other) const;
/**
* @brief Do quaternion multiplication
*/
Quaternion operator*(const Quaternion &other) const;
/**
* @brief Multiply the quaternion by a scalar
*/
Quaternion operator*(float scalar) const;
/**
* @brief Add two quaternions together
* @param other The quaternion to add to this one
@@ -40,7 +55,7 @@ public:
* @param buffer The buffer to store the result in
* @return A reference to the buffer
*/
Quaternion &Q_Mult(Quaternion &other, Quaternion &buffer) const;
Quaternion &Q_Mult(const Quaternion &other, Quaternion &buffer) const;
/**
* @brief Rotate a quaternion by this quaternion
@@ -55,6 +70,8 @@ public:
*/
void Normalize();
Matrix<3,3> ToRotationMatrix() const;
// Give people an easy way to access the elements
float &w{matrix[0]};
float &v1{matrix[1]};