From 0c59839dfa3164fb8055744ed4a18888849619ed Mon Sep 17 00:00:00 2001 From: Quinn Date: Fri, 24 Nov 2023 17:09:00 -0500 Subject: [PATCH] Cleaned up scan matching implimentation. --- src/ScanGraph/ScanGraph.java | 234 +++++++++++++++++++++-------------- src/Vector/Vector.java | 13 ++ 2 files changed, 151 insertions(+), 96 deletions(-) diff --git a/src/ScanGraph/ScanGraph.java b/src/ScanGraph/ScanGraph.java index 5751545..bb719b9 100644 --- a/src/ScanGraph/ScanGraph.java +++ b/src/ScanGraph/ScanGraph.java @@ -29,117 +29,36 @@ public class ScanGraph extends Graph{ * @return null if no match can be found, or an existing scan the matches the new scan. */ private ScanPoint getAssociatedScan(ScanPoint newScan) { - + ScanMatcher matcher = new ScanMatcher(); // go through all of our available scans and try to match the new scan with the old scans. If no match can be found return null for (Vertex v : adjList.keySet()) { ScanPoint referenceScan = (ScanPoint) v; - // p is the newScan and q is the referenceScan - CorrespondenceMatrix correspondenceMatrix = new CorrespondenceMatrix(newScan, referenceScan); + for(int i = 0; i < 5; i++) { + // calculate the rotation and translation matrices between the new scan and the reference scan + matcher.calculateRotationAndTranslationMatrices(referenceScan, newScan); - // compute the average position of the new scan - Vector averagePosition = new Vector(0, 0); - int invalidPoints = 0; - for (Vector point : newScan.getScan()) { - if (point != null) { - averagePosition = averagePosition.add(point); - } - else{ - invalidPoints++; + // update the new scan with the rotation matrix and translation vector + newScan = matcher.applyRotationAndTranslationMatrices(newScan); + + // calculate the error between the new scan and the reference scan + float error = matcher.getError(referenceScan, newScan); + + // if the error is less than some threshold, then we have found a match + if (error < 0.1) { + return referenceScan; } } - SimpleMatrix averagePositionVector = new SimpleMatrix(averagePosition.div(newScan.getScan().size() - invalidPoints).toArray()); - - // compute the average position of the reference scan - Vector averageReferencePosition = new Vector(0, 0); - invalidPoints = 0; - for (Vector point : referenceScan.getScan()) { - if (point != null) { - averageReferencePosition = averageReferencePosition.add(point); - } - else{ - invalidPoints++; - } - } - SimpleMatrix averageReferencePositionVector = new SimpleMatrix(averageReferencePosition.div(referenceScan.getScan().size() - invalidPoints).toArray()); - - // compute the cross covariance matrix which is given by the formula: - // covariance = the sum from 1 to N of (p_i) * (q_i)^T - // where p_i is the ith point in the new scan and q_i is the ith point in the reference scan and N is the number of points in the scan - // the cross covariance matrix is a 2x2 matrix - float[][] crossCovarianceMatrix = new float[2][2]; - for (int i = 0; i < correspondenceMatrix.getOldPointIndices().size(); i++) { - int oldIndex = correspondenceMatrix.getOldPointIndices().get(i); - int newIndex = correspondenceMatrix.getNewPointIndices().get(i); - Vector oldPoint = referenceScan.getScan().get(oldIndex); - Vector newPoint = newScan.getScan().get(newIndex); - if (oldPoint != null && newPoint != null) { - Vector oldPointCentered = oldPoint.sub(averageReferencePosition); - Vector newPointCentered = newPoint.sub(averagePosition); - crossCovarianceMatrix[0][0] += oldPointCentered.x * newPointCentered.x; - crossCovarianceMatrix[0][1] += oldPointCentered.x * newPointCentered.y; - crossCovarianceMatrix[1][0] += oldPointCentered.y * newPointCentered.x; - crossCovarianceMatrix[1][1] += oldPointCentered.y * newPointCentered.y; - } - } - - // convert the cross covariance matrix to a simple matrix from ejml - SimpleMatrix crossCovarianceMatrixSimple = new SimpleMatrix(crossCovarianceMatrix); - // perform the single value decomposition on the cross covariance matrix - SimpleSVD svd = crossCovarianceMatrixSimple.svd(); - // get the rotation matrix from the svd - SimpleMatrix rotationMatrix = (SimpleMatrix) svd.getU().mult(svd.getV().transpose()); - // get the translation vector from the svd - SimpleMatrix translationVector = averageReferencePositionVector.minus(rotationMatrix.mult(averagePositionVector)); - - // update the new scan with the rotation matrix and translation vector - for (int i = 0; i < newScan.getScan().size(); i++) { - Vector point = newScan.getScan().get(i); - if (point != null) { - SimpleMatrix pointMatrix = new SimpleMatrix(point.toArray()); - SimpleMatrix newPointMatrix = rotationMatrix.mult(pointMatrix).plus(translationVector); - newScan.getScan().set(i, new Vector((float) newPointMatrix.get(0), (float) newPointMatrix.get(1))); - } - } - - // calculate the error between the new scan and the reference scan - float error = 0; - for (int i = 0; i < correspondenceMatrix.getOldPointIndices().size(); i++) { - int oldIndex = correspondenceMatrix.getOldPointIndices().get(i); - int newIndex = correspondenceMatrix.getNewPointIndices().get(i); - Vector oldPoint = referenceScan.getScan().get(oldIndex); - Vector newPoint = newScan.getScan().get(newIndex); - if (oldPoint != null && newPoint != null) { - error += correspondenceMatrix.getDistances().get(i); - } - } - error /= correspondenceMatrix.getOldPointIndices().size(); - - // if the error is less than some threshold, then we have found a match - if (error < 0.1) { - return referenceScan; - } - - // TODO: iteratively update the scan up to 5 times before determining that there is no match. - - } - return null; } - private void singleValueDecomposition(float[][] matrix){ - // compute the single value decomposition of the matrix - - // matrix multiply the matrix by its transpose - - } - } /** * @brief A class to hold the correspondence matrix between two scans - * The correspondence matrix is a 3xN matrix where N is the number of valid points in the scan + * The correspondence matrix is a 3xN matrix where N is the number of valid points in the scan. + * This calculates the closest point in the old scan for each point in the new scan and gets rid of redundant closest points. */ class CorrespondenceMatrix{ private ArrayList oldPointIndices = new ArrayList<>(); @@ -162,6 +81,11 @@ class CorrespondenceMatrix{ return this.distances; } + /** + * @brief Calculate the correspondence matrix between two scans + * @param newScan the new scan + * @param referenceScan the reference scan + */ private void calculateCorrespondenceMatrix(ScanPoint newScan, ScanPoint referenceScan){ // compute the correspondence matrix between the two scans. It is a 3xN matrix where N is the number of points in the scan // Row 1 is the index of the point in the old scan @@ -178,13 +102,16 @@ class CorrespondenceMatrix{ // go through all of the points in the new scan and find the closest point in the old scan for (int newPointIndex = 0; newPointIndex < newScan.getScan().size(); newPointIndex++) { Vector newPoint = newScan.getScan().get(newPointIndex); + // if the new point is null, then skip it if (newPoint == null) { continue; } + // find the closest point in the old scan float closestDistance = Float.MAX_VALUE; int closestIndex = -1; for (int j = 0; j < referenceScan.getScan().size(); j++) { Vector oldPoint = referenceScan.getScan().get(j); + // if the old point is null, then skip it if (oldPoint == null) { continue; } @@ -215,3 +142,118 @@ class CorrespondenceMatrix{ } } } + +class ScanMatcher{ + // A 2x2 matrix describing a rotation to apply to the new scan + SimpleMatrix rotationMatrix; + + // A 2x1 matrix describing a translation to apply to the new scan + SimpleMatrix translationVector; + + ScanMatcher(){ + } + + /** + * @brief Compute the average position of the scan + * @param scan the scan to compute the average position of + * @return a 2x1 matrix containing the x,y coordinates of the average position of the scan + */ + private SimpleMatrix averageScanPosition(ScanPoint scan){ + Vector averagePosition = new Vector(0, 0); + int invalidPoints = 0; + for (Vector point : scan.getScan()) { + if (point != null) { + averagePosition = averagePosition.add(point); + } + else{ + invalidPoints++; + } + } + return new SimpleMatrix(averagePosition.div(scan.getScan().size() - invalidPoints).toArray()); + } + + /** + * @brief Compute the cross covariance matrix between the new scan and the reference scan + * @return a 2x2 matrix containing the cross covariance matrix + */ + private SimpleMatrix crossCovarianceMatrix(ScanPoint referenceScan, ScanPoint newScan){ + Vector referenceScanAveragePosition = new Vector(averageScanPosition(referenceScan)); + Vector newScanAveragePosition = new Vector(averageScanPosition(newScan)); + + CorrespondenceMatrix correspondenceMatrix = new CorrespondenceMatrix(newScan, referenceScan); + + // compute the cross covariance matrix which is given by the formula: + // covariance = the sum from 1 to N of (p_i) * (q_i)^T + // where p_i is the ith point in the new scan and q_i is the ith point in the reference scan and N is the number of points in the scan + // the cross covariance matrix is a 2x2 matrix + float[][] crossCovarianceMatrix = new float[2][2]; + for (int i = 0; i < correspondenceMatrix.getOldPointIndices().size(); i++) { + int oldIndex = correspondenceMatrix.getOldPointIndices().get(i); + int newIndex = correspondenceMatrix.getNewPointIndices().get(i); + Vector oldPoint = referenceScan.getScan().get(oldIndex); + Vector newPoint = newScan.getScan().get(newIndex); + if (oldPoint != null && newPoint != null) { + Vector oldPointCentered = oldPoint.sub(referenceScanAveragePosition); + Vector newPointCentered = newPoint.sub(newScanAveragePosition); + crossCovarianceMatrix[0][0] += oldPointCentered.x * newPointCentered.x; + crossCovarianceMatrix[0][1] += oldPointCentered.x * newPointCentered.y; + crossCovarianceMatrix[1][0] += oldPointCentered.y * newPointCentered.x; + crossCovarianceMatrix[1][1] += oldPointCentered.y * newPointCentered.y; + } + } + return new SimpleMatrix(crossCovarianceMatrix); + } + + /** + * @brief Compute the rotation and translation matrices between the new scan and the reference scan. Then cache them as private variables. + * The rotation matrix is a 2x2 matrix and the translation vector is a 2x1 matrix + */ + public void calculateRotationAndTranslationMatrices(ScanPoint referenceScan, ScanPoint newScan){ + // compute the rotation matrix which is given by the formula: + // R = V * U^T + // where V and U are the singular value decomposition of the cross covariance matrix + // the rotation matrix is a 2x2 matrix + SimpleMatrix crossCovarianceMatrixSimple = crossCovarianceMatrix(referenceScan, newScan); + SimpleSVD svd = crossCovarianceMatrixSimple.svd(); + this.rotationMatrix = svd.getU().mult(svd.getV().transpose()); + + SimpleMatrix newScanAveragePosition = averageScanPosition(newScan); + SimpleMatrix referenceScanAveragePosition = averageScanPosition(referenceScan); + this.translationVector = referenceScanAveragePosition.minus(rotationMatrix.mult(newScanAveragePosition)); + } + + public SimpleMatrix getRotationMatrix(){ + return this.rotationMatrix; + } + + public SimpleMatrix getTranslationVector(){ + return this.translationVector; + } + + public ScanPoint applyRotationAndTranslationMatrices(ScanPoint newScan){ + // apply the rotation matrix and translation vector to the new scan + for (int i = 0; i < newScan.getScan().size(); i++) { + Vector point = newScan.getScan().get(i); + if (point != null) { + SimpleMatrix pointMatrix = new SimpleMatrix(point.toArray()); + SimpleMatrix newPointMatrix = rotationMatrix.mult(pointMatrix).plus(translationVector); + newScan.getScan().set(i, new Vector((float) newPointMatrix.get(0), (float) newPointMatrix.get(1))); + } + } + return newScan; + } + + public float getError(ScanPoint referenceScan, ScanPoint newScan){ + // calculate the error between the new scan and the reference scan + // q is reference scan and p is new scan + // error is given as abs(Q_mean - R * P_mean) + // where Q_mean is the average position of the reference scan + // P_mean is the average position of the new scan + // R is the rotation matrix + + SimpleMatrix newScanAveragePosition = averageScanPosition(newScan); + SimpleMatrix referenceScanAveragePosition = averageScanPosition(referenceScan); + SimpleMatrix error = referenceScanAveragePosition.minus(rotationMatrix.mult(newScanAveragePosition)); + return (float) error.elementSum(); + } +} diff --git a/src/Vector/Vector.java b/src/Vector/Vector.java index 11166bd..c388589 100644 --- a/src/Vector/Vector.java +++ b/src/Vector/Vector.java @@ -1,5 +1,6 @@ package Vector; +import org.ejml.simple.SimpleMatrix; import processing.core.PApplet; import static java.lang.Math.*; @@ -22,6 +23,18 @@ public class Vector { this.y = y; this.z = z; } + + public Vector(SimpleMatrix matrix){ + // initialize x,y if matrix is 2x1 and x,y,z if matrix is 3x1 + if(matrix.getNumRows() == 2){ + this.x = (float)matrix.get(0,0); + this.y = (float)matrix.get(1,0); + }else if(matrix.getNumRows() == 3){ + this.x = (float)matrix.get(0,0); + this.y = (float)matrix.get(1,0); + this.z = (float)matrix.get(2,0); + } + } public Vector add(Vector other){ return new Vector(this.x + other.x, this.y + other.y, this.z + other.z);