Refactors the underlying data type to optimize performance

This commit is contained in:
Quinn Henthorne
2024-12-13 17:04:46 -05:00
parent 53f3766658
commit 941758e8ea
4 changed files with 39 additions and 35 deletions

View File

@@ -24,19 +24,18 @@ Matrix<rows, columns>::Matrix(Args... args) {
static_cast<uint16_t>(columns)}; static_cast<uint16_t>(columns)};
std::initializer_list<float> initList{static_cast<float>(args)...}; std::initializer_list<float> initList{static_cast<float>(args)...};
std::array<float, arraySize> data{};
// choose whichever buffer size is smaller for the copy length // choose whichever buffer size is smaller for the copy length
uint32_t minSize = uint32_t minSize =
std::min(arraySize, static_cast<uint16_t>(initList.size())); std::min(arraySize, static_cast<uint16_t>(initList.size()));
memcpy(data.begin(), initList.begin(), minSize * sizeof(float)); memcpy(this->matrix.begin(), initList.begin(), minSize * sizeof(float));
this->setMatrixToArray(data);
} }
template <uint8_t rows, uint8_t columns> template <uint8_t rows, uint8_t columns>
Matrix<rows, columns>::Matrix(const Matrix<rows, columns> &other) { Matrix<rows, columns>::Matrix(const Matrix<rows, columns> &other) {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
this->matrix[row_idx][column_idx] = other.Get(row_idx, column_idx); this->matrix[row_idx * columns + column_idx] =
other.Get(row_idx, column_idx);
} }
} }
} }
@@ -50,9 +49,9 @@ void Matrix<rows, columns>::setMatrixToArray(
static_cast<uint16_t>(row_idx) * static_cast<uint16_t>(columns) + static_cast<uint16_t>(row_idx) * static_cast<uint16_t>(columns) +
static_cast<uint16_t>(column_idx); static_cast<uint16_t>(column_idx);
if (array_idx < array.size()) { if (array_idx < array.size()) {
this->matrix[row_idx][column_idx] = array[array_idx]; this->matrix[row_idx * columns + column_idx] = array[array_idx];
} else { } else {
this->matrix[row_idx][column_idx] = 0; this->matrix[row_idx * columns + column_idx] = 0;
} }
} }
} }
@@ -175,10 +174,9 @@ Matrix<rows, columns>::Transpose(Matrix<columns, rows> &result) const {
// explicitly define the determinant for a 2x2 matrix because it is definitely // explicitly define the determinant for a 2x2 matrix because it is definitely
// the fastest way to calculate a 2x2 matrix determinant // the fastest way to calculate a 2x2 matrix determinant
template <> float Matrix<0, 0>::Det() const { return 1e+6; } template <> float Matrix<0, 0>::Det() const { return 1e+6; }
template <> float Matrix<1, 1>::Det() const { return this->matrix[0][0]; } template <> float Matrix<1, 1>::Det() const { return this->matrix[0]; }
template <> float Matrix<2, 2>::Det() const { template <> float Matrix<2, 2>::Det() const {
return this->matrix[0][0] * this->matrix[1][1] - return this->matrix[0] * this->matrix[3] - this->matrix[1] * this->matrix[2];
this->matrix[0][1] * this->matrix[1][0];
} }
template <uint8_t rows, uint8_t columns> template <uint8_t rows, uint8_t columns>
@@ -191,7 +189,7 @@ float Matrix<rows, columns>::Det() const {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
// for odd indices the sign is negative // for odd indices the sign is negative
float sign = (column_idx % 2 == 0) ? 1 : -1; float sign = (column_idx % 2 == 0) ? 1 : -1;
determinant += sign * this->matrix[0][column_idx] * determinant += sign * this->matrix[column_idx] *
this->MinorMatrix(MinorMatrix, 0, column_idx).Det(); this->MinorMatrix(MinorMatrix, 0, column_idx).Det();
} }
@@ -233,14 +231,15 @@ float Matrix<rows, columns>::Get(uint8_t row_index,
return 1e+10; // TODO: We should throw something here instead of failing return 1e+10; // TODO: We should throw something here instead of failing
// quietly // quietly
} }
return this->matrix[row_index][column_index]; return this->matrix[row_index * columns + column_index];
} }
template <uint8_t rows, uint8_t columns> template <uint8_t rows, uint8_t columns>
Matrix<1, columns> & Matrix<1, columns> &
Matrix<rows, columns>::GetRow(uint8_t row_index, Matrix<rows, columns>::GetRow(uint8_t row_index,
Matrix<1, columns> &row) const { Matrix<1, columns> &row) const {
row = Matrix<1, columns>(this->matrix[row_index]); memcpy(&(row[0]), this->matrix.begin() + row_index * columns,
columns * sizeof(float));
return row; return row;
} }
@@ -261,7 +260,8 @@ void Matrix<rows, columns>::ToString(std::string &stringBuffer) const {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
stringBuffer += "|"; stringBuffer += "|";
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
stringBuffer += std::to_string(this->matrix[row_idx][column_idx]); stringBuffer +=
std::to_string(this->matrix[row_idx * columns + column_idx]);
if (column_idx != columns - 1) { if (column_idx != columns - 1) {
stringBuffer += "\t"; stringBuffer += "\t";
} }
@@ -274,10 +274,13 @@ template <uint8_t rows, uint8_t columns>
std::array<float, columns> &Matrix<rows, columns>:: std::array<float, columns> &Matrix<rows, columns>::
operator[](uint8_t row_index) { operator[](uint8_t row_index) {
if (row_index > rows - 1) { if (row_index > rows - 1) {
return this->matrix[0]; // TODO: We should throw something here instead of // TODO: We should throw something here instead of failing quietly.
// failing quietly. row_index = 0;
} }
return this->matrix[row_index]; // cursed reinterpret_cast that will help us fake having a nested array when
// we really don't
return *reinterpret_cast<std::array<float, columns> *>(
&(this->matrix[row_index * columns]));
} }
template <uint8_t rows, uint8_t columns> template <uint8_t rows, uint8_t columns>
@@ -285,7 +288,8 @@ Matrix<rows, columns> &Matrix<rows, columns>::
operator=(const Matrix<rows, columns> &other) { operator=(const Matrix<rows, columns> &other) {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
this->matrix[row_idx][column_idx] = other.Get(row_idx, column_idx); this->matrix[row_idx * columns + column_idx] =
other.Get(row_idx, column_idx);
} }
} }
// return a reference to ourselves so you can chain together these functions // return a reference to ourselves so you can chain together these functions
@@ -351,7 +355,7 @@ template <uint8_t rows, uint8_t columns>
void Matrix<rows, columns>::Fill(float value) { void Matrix<rows, columns>::Fill(float value) {
for (uint8_t row_idx{0}; row_idx < rows; row_idx++) { for (uint8_t row_idx{0}; row_idx < rows; row_idx++) {
for (uint8_t column_idx{0}; column_idx < columns; column_idx++) { for (uint8_t column_idx{0}; column_idx < columns; column_idx++) {
this->matrix[row_idx][column_idx] = value; this->matrix[row_idx * columns + column_idx] = value;
} }
} }
} }

View File

@@ -198,7 +198,7 @@ private:
void setMatrixToArray(const std::array<float, rows * columns> &array); void setMatrixToArray(const std::array<float, rows * columns> &array);
std::array<std::array<float, columns>, rows> matrix; std::array<float, rows * columns> matrix;
}; };
#include "Matrix.cpp" #include "Matrix.cpp"

View File

@@ -1,14 +1,14 @@
Addition: 0.55 s Addition: 0.419 s
Subtraction: 0.548 s Subtraction: 0.421 s
Multiplication: 3.404 s Multiplication: 3.297 s
Scalar Multiplication: 0.453 s Scalar Multiplication: 0.329 s
Element Multiply: 0.347 s Element Multiply: 0.306 s
Element Divide: 0.347 s Element Divide: 0.302 s
Minor Matrix: 0.44 s Minor Matrix: 0.331 s
Determinant: 0.251 s Determinant: 0.177 s
Matrix of Minors: 1.044 s Matrix of Minors: 0.766 s
Invert: 0.262 s Invert: 0.183 s
Transpose: 0.26 s Transpose: 0.215 s
Normalize: 0.333 s Normalize: 0.315 s
GET ROW: 0.683 s GET ROW: 0.008 s
GET COLUMN: 0.427 s GET COLUMN: 0.43 s

View File

@@ -83,7 +83,7 @@ for new_timing in new_timings:
for old_timing in old_timings: for old_timing in old_timings:
if new_timing == old_timing: if new_timing == old_timing:
new_timing.difference = new_timing - old_timing new_timing.difference = new_timing - old_timing
if new_timing.difference >= 0.03: if abs(new_timing.difference) >= 0.03:
difference_increased += f"{new_timing.test_name}, " difference_increased += f"{new_timing.test_name}, "
def save_option(): def save_option():
@@ -102,7 +102,7 @@ for timing in new_timings:
print(timing.to_string_w_diff()) print(timing.to_string_w_diff())
if len(difference_increased) > 0: if len(difference_increased) > 0:
print("You increased the time it takes to run for:" + difference_increased) print("You've made major timing changes for:" + difference_increased)
save_option() save_option()
else: else:
print("No times have changed outside the margin of error.") print("No times have changed outside the margin of error.")