diff --git a/hipparchus-fitting/src/changes/changes.xml b/hipparchus-fitting/src/changes/changes.xml index 9e01f2013..a2e68b05b 100644 --- a/hipparchus-fitting/src/changes/changes.xml +++ b/hipparchus-fitting/src/changes/changes.xml @@ -50,6 +50,9 @@ If the output is not quite correct, check for invisible trailing spaces! + + Replaced double[] arrays with Fittable interface in RANSAC fitting classes. + Added RANSAC algorithm for estimating the parameters of a mathematical model. diff --git a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/Fittable.java b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/Fittable.java new file mode 100644 index 000000000..6b0521a3c --- /dev/null +++ b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/Fittable.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.hipparchus.fitting.ransac; + +/** + * Interface for data points that can be used with {@link RansacFitter}. + * @since 4.1 + */ +public interface Fittable { + + /** + * Gets the n-dimensional point. + * @return the point array (beware, it may be a reference to an internal array) + */ + double[] getPoint(); +} diff --git a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/IModelFitter.java b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/IModelFitter.java index f3c1a6322..2413f7b14 100644 --- a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/IModelFitter.java +++ b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/IModelFitter.java @@ -30,7 +30,7 @@ public interface IModelFitter { * @param points set of observed data * @return the fitted model parameters */ - M fitModel(final List points); + M fitModel(List points); /** * Computes the error between the model and an observed data. @@ -41,5 +41,5 @@ public interface IModelFitter { * @param point observed data * @return the error between the model and the observed data */ - double computeModelError(final M model, final double[] point); + double computeModelError(M model, Fittable point); } diff --git a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/PolynomialModelFitter.java b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/PolynomialModelFitter.java index 79d18cffa..646f81887 100644 --- a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/PolynomialModelFitter.java +++ b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/PolynomialModelFitter.java @@ -86,7 +86,7 @@ public PolynomialModelFitter(final int degree) { /** {@inheritDoc} */ @Override - public Model fitModel(final List points) { + public Model fitModel(final List points) { // Reference: Wikipedia page "Polynomial regression" final int size = points.size(); checkSampleSize(size); @@ -95,8 +95,9 @@ public Model fitModel(final List points) { final double[][] x = new double[size][degree + 1]; final double[] y = new double[size]; for (int i = 0; i < size; i++) { - final double currentX = points.get(i)[0]; - final double currentY = points.get(i)[1]; + final double[] point = points.get(i).getPoint(); + final double currentX = point[0]; + final double currentY = point[1]; double value = 1.0; for (int j = 0; j <= degree; j++) { x[i][j] = value; @@ -117,8 +118,8 @@ public Model fitModel(final List points) { /** {@inheritDoc}. */ @Override - public double computeModelError(final Model model, final double[] point) { - return FastMath.abs(point[1] - model.predict(point[0])); + public double computeModelError(final Model model, final Fittable point) { + return FastMath.abs(point.getPoint()[1] - model.predict(point.getPoint()[0])); } /** diff --git a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitter.java b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitter.java index 65255f413..de193bb1f 100644 --- a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitter.java +++ b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitter.java @@ -91,19 +91,19 @@ public RansacFitter(final IModelFitter fitter, final int sampleSize, * @param points set of observed data * @return a java class containing the best estimate of the model parameters */ - public RansacFitterOutputs fit(final List points) { + public RansacFitterOutputs fit(final List points) { // Initialize the best model data - final List data = new ArrayList<>(points); + final List data = new ArrayList<>(points); Optional bestModel = Optional.empty(); - List bestInliers = new ArrayList<>(); + List bestInliers = new ArrayList<>(); // Iterative loop to determine the best model for (int iteration = 0; iteration < maxIterations; iteration++) { // Random permute the set of observed data and determine the inliers Collections.shuffle(data, random); - final List inliers = determineCurrentInliersFromRandomlyPermutedPoints(data); + final List inliers = determineCurrentInliersFromRandomlyPermutedPoints(data); // Verifies if the current inliers are fit better the model than the previous ones if (isCurrentInliersSetBetterThanPreviousOne(inliers, bestInliers)) { @@ -122,8 +122,8 @@ public RansacFitterOutputs fit(final List points) { * @param permutedPoints randomly permuted data * @return the list of inliers */ - private List determineCurrentInliersFromRandomlyPermutedPoints(final List permutedPoints) { - M model = fitter.fitModel(permutedPoints.subList(0, sampleSize)); + private List determineCurrentInliersFromRandomlyPermutedPoints(final List permutedPoints) { + final M model = fitter.fitModel(permutedPoints.subList(0, sampleSize)); return permutedPoints.stream().filter(point -> fitter.computeModelError(model, point) < threshold).collect(Collectors.toList()); } @@ -133,7 +133,7 @@ private List determineCurrentInliersFromRandomlyPermutedPoints(final L * @param previous previous inliers * @return true is the current inlier are better than the previous ones */ - private boolean isCurrentInliersSetBetterThanPreviousOne(final List current, final List previous) { + private boolean isCurrentInliersSetBetterThanPreviousOne(final List current, final List previous) { return current.size() > previous.size() && current.size() >= minInliers; } diff --git a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitterOutputs.java b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitterOutputs.java index dabd026da..c933d39dd 100644 --- a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitterOutputs.java +++ b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitterOutputs.java @@ -36,7 +36,7 @@ public class RansacFitterOutputs { private final Optional bestModel; /** List of points used to determine the best model parameters. */ - private final List bestInliers; + private final List bestInliers; /** * Constructor. @@ -44,7 +44,7 @@ public class RansacFitterOutputs { * @param bestInliers list of points used to determine the best model parameters * @param fitter mathematical model fitter used by RANSAC algorithm */ - public RansacFitterOutputs(final Optional bestModel, final List bestInliers, final IModelFitter fitter) { + public RansacFitterOutputs(final Optional bestModel, final List bestInliers, final IModelFitter fitter) { this.bestModel = bestModel; this.bestInliers = new ArrayList<>(bestInliers); this.fitter = fitter; @@ -62,7 +62,7 @@ public Optional getBestModel() { * Get the list of points used to determine the best model parameters. * @return the list of points used to determine the best model parameters */ - public List getBestInliers() { + public List getBestInliers() { return new ArrayList<>(bestInliers); } @@ -73,7 +73,7 @@ public List getBestInliers() { * @return the list of points below the given threshold based on the computed best model parameters * (can be empty if the all points are above the threshold or if no best model has been found) */ - public List filterPointsBelowThreshold(final List points, final double threshold) { + public List filterPointsBelowThreshold(final List points, final double threshold) { return bestModel.map(model -> points.stream().filter(point -> fitter.computeModelError(model, point) < threshold).collect(Collectors.toList())) .orElse(Collections.emptyList()); } diff --git a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/package-info.java b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/package-info.java index 8bfa2f7d5..7c612aa1d 100644 --- a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/package-info.java +++ b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/package-info.java @@ -16,6 +16,9 @@ */ /** * Random sample consensus (RANSAC) algorithm implementation. + *

+ * Data points to be fitted must implement the {@link org.hipparchus.fitting.ransac.Fittable} interface. + *

* @since 4.1 */ package org.hipparchus.fitting.ransac; diff --git a/hipparchus-fitting/src/site/markdown/fitting.md b/hipparchus-fitting/src/site/markdown/fitting.md index 4a3627131..a28a73b3c 100644 --- a/hipparchus-fitting/src/site/markdown/fitting.md +++ b/hipparchus-fitting/src/site/markdown/fitting.md @@ -102,15 +102,17 @@ RANSAC-based fitting of specific functions are provided through the following cl * call the fit method of [RansacFitter](../apidocs/org/hipparchus/fitting/ransac/RansacFitter.html) with a List of observed data points as argument, which will return a java class containing the parameters that best fit the given data points. +Data points must implement the [Fittable](../apidocs/org/hipparchus/fitting/ransac/Fittable.html) interface which provides a `getPoint()` method returning the point coordinates as a `double[]`. + The following example shows how to fit data with a polynomial model of degree 2. // Collect data. - final List obs = new ArrayList<>(); - obs.add(0.0, -61.422); - obs.add(2.0, -42.28700013); - obs.add(4.0, -58.97612903); + final List obs = new ArrayList<>(); + obs.add(new SimpleFittable(0.0, -61.422)); + obs.add(new SimpleFittable(2.0, -42.28700013)); + obs.add(new SimpleFittable(4.0, -58.97612903)); // ... Lots of lines omitted ... - obs.add(498.0, -67.39); + obs.add(new SimpleFittable(498.0, -67.39)); // Instantiate the model to fit. final PolynomialModelFitter model = new PolynomialModelFitter(2); diff --git a/hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/RansacFitterTest.java b/hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/RansacFitterTest.java index 036f3f7c8..04905ec14 100644 --- a/hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/RansacFitterTest.java +++ b/hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/RansacFitterTest.java @@ -31,6 +31,18 @@ class RansacFitterTest { + /** Simple implementation of {@link Fittable} wrapping a double array. */ + private static final class SimpleFittable implements Fittable { + private final double[] point; + SimpleFittable(final double[] point) { + this.point = point.clone(); + } + @Override + public double[] getPoint() { + return point.clone(); + } + } + @Test void testExceptionsOnInitialValues() { assertThrows(MathIllegalArgumentException.class, () -> new RansacFitter<>(mockModel(), -1, 6, 1e-6, 10, 1), "-1 is smaller than the minimum (0)"); @@ -52,7 +64,7 @@ void testCanPerfectlyFitALineWithoutNoiseButWithSmallNumberOfOutliers() { @Test void testCanFitALineWithLargeNumberOfOutliers() throws IOException { // This test reproduces the example provided in RANSAC wikipedia page. Results are strongly consistent - final List points = loadData("line_dataset.csv"); + final List points = loadData("line_dataset.csv"); final double standardDeviation = 0.6159842899599051; final RansacFitterOutputs fitted = new RansacFitter<>(new PolynomialModelFitter(1), 10, 100, standardDeviation / 3, 10, 1).fit(points); Assertions.assertNotNull(fitted); @@ -65,7 +77,7 @@ void testCanFitALineWithLargeNumberOfOutliers() throws IOException { @Test void testCanFitAPolynomialOfDegree2WithOutliers() throws IOException { // Reference: https://forum.orekit.org/t/addition-of-ransac-algorithm/5102 - final List points = loadData("quadratic_dataset.csv"); + final List points = loadData("quadratic_dataset.csv"); final double standardDeviation = 72.59099534185657; final RansacFitterOutputs fitted = new RansacFitter<>(new PolynomialModelFitter(2), 10, 1000, standardDeviation / 3, 10, 1).fit(points); Assertions.assertNotNull(fitted); @@ -90,24 +102,24 @@ private void doTestLineFittingWithSmallNumberOfOutliers(final double slopeDelta, Assertions.assertEquals(numberOfTrueData, fitted.getBestInliers().size()); } - private List generateLine(final int seed, final double expectedSlope, final double expectedIntercept, - final int trueDataCount, final int falseDataCount, final double noiseFactor) { + private List generateLine(final int seed, final double expectedSlope, final double expectedIntercept, + final int trueDataCount, final int falseDataCount, final double noiseFactor) { final Random random = new Random(seed); final PolynomialModelFitter.Model trueModel = new PolynomialModelFitter.Model(new double[]{expectedIntercept, expectedSlope}); - final List points = IntStream.range(0, trueDataCount) - .mapToObj(x -> new double[]{x, trueModel.predict(x) + random.nextGaussian() * noiseFactor}) - .collect(Collectors.toList()); - points.addAll(IntStream.range(0, falseDataCount).mapToObj(x -> new double[]{x * 3, random.nextDouble() * 20}).collect(Collectors.toList())); + final List points = IntStream.range(0, trueDataCount) + .mapToObj(x -> new SimpleFittable(new double[]{x, trueModel.predict(x) + random.nextGaussian() * noiseFactor})) + .collect(Collectors.toList()); + points.addAll(IntStream.range(0, falseDataCount).mapToObj(x -> new SimpleFittable(new double[]{x * 3, random.nextDouble() * 20})).collect(Collectors.toList())); return points; } - private List loadData(final String fileName) { + private List loadData(final String fileName) { final InputStream inputStream = this.getClass().getResourceAsStream("/" + this.getClass().getSimpleName() + "/" + fileName); Assertions.assertNotNull(inputStream, "Could not find resource " + fileName); return IOUtil.readLines(inputStream) .stream() .map(line -> line.split(",")) - .map(values -> new double[]{Double.parseDouble(values[0]), Double.parseDouble(values[1])}) + .map(values -> new SimpleFittable(new double[]{Double.parseDouble(values[0]), Double.parseDouble(values[1])})) .collect(Collectors.toList()); } @@ -124,4 +136,4 @@ private PolynomialModelFitter.Model getBestModel(final RansacFitterOutputs new RuntimeException("No model found")); } -} \ No newline at end of file +} diff --git a/src/changes/changes.xml b/src/changes/changes.xml index 593ddd83e..7a588c219 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -50,6 +50,9 @@ If the output is not quite correct, check for invisible trailing spaces! + + Replaced double[] arrays with Fittable interface in RANSAC fitting classes. + Allow small numerical tolerance in SmoothStepFactory.checkBetweenZeroAndOneIncluded