Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions hipparchus-fitting/src/changes/changes.xml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ If the output is not quite correct, check for invisible trailing spaces!
</properties>
<body>
<release version="4.1" date="TBD" description="TBD">
<action dev="bryan" type="update" issue="issues/462">
Replaced double[] arrays with Fittable interface in RANSAC fitting classes.
</action>
<action dev="bryan" type="add" issue="issues/424">
Added RANSAC algorithm for estimating the parameters of a mathematical model.
</action>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.hipparchus.fitting.ransac;

/**
 * Interface for data points that can be used with {@link RansacFitter}.
 * <p>
 * An implementation exposes its coordinates as a plain array through {@link #getPoint()}.
 * The meaning of each index is decided by the model fitter consuming the points; for
 * instance {@code PolynomialModelFitter} reads index 0 as the abscissa and index 1 as
 * the ordinate.
 * </p>
 * @since 4.1
 */
public interface Fittable {

/**
 * Gets the n-dimensional point.
 * <p>
 * Callers should treat the returned array as read-only: an implementation is allowed
 * to hand back its internal storage rather than a copy.
 * </p>
 * @return the point array (beware, it may be a reference to an internal array)
 */
double[] getPoint();
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public interface IModelFitter<M> {
* @param points set of observed data
* @return the fitted model parameters
*/
M fitModel(final List<double[]> points);
M fitModel(List<Fittable> points);

/**
* Computes the error between the model and an observed data.
Expand All @@ -41,5 +41,5 @@ public interface IModelFitter<M> {
* @param point observed data
* @return the error between the model and the observed data
*/
double computeModelError(final M model, final double[] point);
double computeModelError(M model, Fittable point);
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public PolynomialModelFitter(final int degree) {

/** {@inheritDoc} */
@Override
public Model fitModel(final List<double[]> points) {
public Model fitModel(final List<Fittable> points) {
// Reference: Wikipedia page "Polynomial regression"
final int size = points.size();
checkSampleSize(size);
Expand All @@ -95,8 +95,9 @@ public Model fitModel(final List<double[]> points) {
final double[][] x = new double[size][degree + 1];
final double[] y = new double[size];
for (int i = 0; i < size; i++) {
final double currentX = points.get(i)[0];
final double currentY = points.get(i)[1];
final double[] point = points.get(i).getPoint();
final double currentX = point[0];
final double currentY = point[1];
double value = 1.0;
for (int j = 0; j <= degree; j++) {
x[i][j] = value;
Expand All @@ -117,8 +118,8 @@ public Model fitModel(final List<double[]> points) {

/** {@inheritDoc}. */
@Override
public double computeModelError(final Model model, final double[] point) {
return FastMath.abs(point[1] - model.predict(point[0]));
public double computeModelError(final Model model, final Fittable point) {
return FastMath.abs(point.getPoint()[1] - model.predict(point.getPoint()[0]));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,19 @@ public RansacFitter(final IModelFitter<M> fitter, final int sampleSize,
* @param points set of observed data
* @return a java class containing the best estimate of the model parameters
*/
public RansacFitterOutputs<M> fit(final List<double[]> points) {
public RansacFitterOutputs<M> fit(final List<Fittable> points) {

// Initialize the best model data
final List<double[]> data = new ArrayList<>(points);
final List<Fittable> data = new ArrayList<>(points);
Optional<M> bestModel = Optional.empty();
List<double[]> bestInliers = new ArrayList<>();
List<Fittable> bestInliers = new ArrayList<>();

// Iterative loop to determine the best model
for (int iteration = 0; iteration < maxIterations; iteration++) {

// Random permute the set of observed data and determine the inliers
Collections.shuffle(data, random);
final List<double[]> inliers = determineCurrentInliersFromRandomlyPermutedPoints(data);
final List<Fittable> inliers = determineCurrentInliersFromRandomlyPermutedPoints(data);

// Check whether the current inliers fit the model better than the previous ones
if (isCurrentInliersSetBetterThanPreviousOne(inliers, bestInliers)) {
Expand All @@ -122,8 +122,8 @@ public RansacFitterOutputs<M> fit(final List<double[]> points) {
* @param permutedPoints randomly permuted data
* @return the list of inliers
*/
private List<double[]> determineCurrentInliersFromRandomlyPermutedPoints(final List<double[]> permutedPoints) {
M model = fitter.fitModel(permutedPoints.subList(0, sampleSize));
private List<Fittable> determineCurrentInliersFromRandomlyPermutedPoints(final List<Fittable> permutedPoints) {
final M model = fitter.fitModel(permutedPoints.subList(0, sampleSize));
return permutedPoints.stream().filter(point -> fitter.computeModelError(model, point) < threshold).collect(Collectors.toList());
}

Expand All @@ -133,7 +133,7 @@ private List<double[]> determineCurrentInliersFromRandomlyPermutedPoints(final L
* @param previous previous inliers
* @return true if the current inliers are better than the previous ones
*/
private boolean isCurrentInliersSetBetterThanPreviousOne(final List<double[]> current, final List<double[]> previous) {
private boolean isCurrentInliersSetBetterThanPreviousOne(final List<Fittable> current, final List<Fittable> previous) {
return current.size() > previous.size() && current.size() >= minInliers;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ public class RansacFitterOutputs<M> {
private final Optional<M> bestModel;

/** List of points used to determine the best model parameters. */
private final List<double[]> bestInliers;
private final List<Fittable> bestInliers;

/**
* Constructor.
* @param bestModel best model parameters
* @param bestInliers list of points used to determine the best model parameters
* @param fitter mathematical model fitter used by RANSAC algorithm
*/
public RansacFitterOutputs(final Optional<M> bestModel, final List<double[]> bestInliers, final IModelFitter<M> fitter) {
public RansacFitterOutputs(final Optional<M> bestModel, final List<Fittable> bestInliers, final IModelFitter<M> fitter) {
this.bestModel = bestModel;
this.bestInliers = new ArrayList<>(bestInliers);
this.fitter = fitter;
Expand All @@ -62,7 +62,7 @@ public Optional<M> getBestModel() {
* Get the list of points used to determine the best model parameters.
* @return the list of points used to determine the best model parameters
*/
public List<double[]> getBestInliers() {
public List<Fittable> getBestInliers() {
return new ArrayList<>(bestInliers);
}

Expand All @@ -73,7 +73,7 @@ public List<double[]> getBestInliers() {
* @return the list of points below the given threshold based on the computed best model parameters
* (can be empty if all the points are above the threshold or if no best model has been found)
*/
public List<double[]> filterPointsBelowThreshold(final List<double[]> points, final double threshold) {
public List<Fittable> filterPointsBelowThreshold(final List<Fittable> points, final double threshold) {
return bestModel.map(model -> points.stream().filter(point -> fitter.computeModelError(model, point) < threshold).collect(Collectors.toList()))
.orElse(Collections.emptyList());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
*/
/**
* Random sample consensus (RANSAC) algorithm implementation.
* <p>
* Data points to be fitted must implement the {@link org.hipparchus.fitting.ransac.Fittable} interface.
* </p>
* @since 4.1
*/
package org.hipparchus.fitting.ransac;
12 changes: 7 additions & 5 deletions hipparchus-fitting/src/site/markdown/fitting.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,17 @@ RANSAC-based fitting of specific functions are provided through the following cl
* call the fit method of [RansacFitter](../apidocs/org/hipparchus/fitting/ransac/RansacFitter.html) with a List of observed data points as argument, which will return
a java class containing the parameters that best fit the given data points.

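Data points must implement the [Fittable](../apidocs/org/hipparchus/fitting/ransac/Fittable.html) interface, which provides a `getPoint()` method returning the point coordinates as a `double[]`. In the example below, `SimpleFittable` stands for a user-provided implementation of this interface.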

The following example shows how to fit data with a polynomial model of degree 2.

// Collect data.
final List<double[]> obs = new ArrayList<>();
obs.add(0.0, -61.422);
obs.add(2.0, -42.28700013);
obs.add(4.0, -58.97612903);
final List<Fittable> obs = new ArrayList<>();
obs.add(new SimpleFittable(0.0, -61.422));
obs.add(new SimpleFittable(2.0, -42.28700013));
obs.add(new SimpleFittable(4.0, -58.97612903));
// ... Lots of lines omitted ...
obs.add(498.0, -67.39);
obs.add(new SimpleFittable(498.0, -67.39));

// Instantiate the model to fit.
final PolynomialModelFitter model = new PolynomialModelFitter(2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@

class RansacFitterTest {

/**
 * Minimal {@link Fittable} for the tests: an immutable holder around a coordinate array.
 */
private static final class SimpleFittable implements Fittable {

    /** Private snapshot of the coordinates, never shared with callers. */
    private final double[] coordinates;

    /**
     * Builds a data point from its coordinates.
     * @param point coordinates of the data point (copied, so later changes made by the caller are not seen)
     */
    SimpleFittable(final double[] point) {
        final double[] snapshot = point.clone();
        this.coordinates = snapshot;
    }

    /**
     * {@inheritDoc}
     * <p>
     * A fresh copy is returned on every call, so the stored coordinates cannot be altered from outside.
     * </p>
     */
    @Override
    public double[] getPoint() {
        final int dimension = coordinates.length;
        final double[] copy = new double[dimension];
        System.arraycopy(coordinates, 0, copy, 0, dimension);
        return copy;
    }
}

@Test
void testExceptionsOnInitialValues() {
assertThrows(MathIllegalArgumentException.class, () -> new RansacFitter<>(mockModel(), -1, 6, 1e-6, 10, 1), "-1 is smaller than the minimum (0)");
Expand All @@ -52,7 +64,7 @@ void testCanPerfectlyFitALineWithoutNoiseButWithSmallNumberOfOutliers() {
@Test
void testCanFitALineWithLargeNumberOfOutliers() throws IOException {
// This test reproduces the example provided in RANSAC wikipedia page. Results are strongly consistent
final List<double[]> points = loadData("line_dataset.csv");
final List<Fittable> points = loadData("line_dataset.csv");
final double standardDeviation = 0.6159842899599051;
final RansacFitterOutputs<PolynomialModelFitter.Model> fitted = new RansacFitter<>(new PolynomialModelFitter(1), 10, 100, standardDeviation / 3, 10, 1).fit(points);
Assertions.assertNotNull(fitted);
Expand All @@ -65,7 +77,7 @@ void testCanFitALineWithLargeNumberOfOutliers() throws IOException {
@Test
void testCanFitAPolynomialOfDegree2WithOutliers() throws IOException {
// Reference: https://forum.orekit.org/t/addition-of-ransac-algorithm/5102
final List<double[]> points = loadData("quadratic_dataset.csv");
final List<Fittable> points = loadData("quadratic_dataset.csv");
final double standardDeviation = 72.59099534185657;
final RansacFitterOutputs<PolynomialModelFitter.Model> fitted = new RansacFitter<>(new PolynomialModelFitter(2), 10, 1000, standardDeviation / 3, 10, 1).fit(points);
Assertions.assertNotNull(fitted);
Expand All @@ -90,24 +102,24 @@ private void doTestLineFittingWithSmallNumberOfOutliers(final double slopeDelta,
Assertions.assertEquals(numberOfTrueData, fitted.getBestInliers().size());
}

private List<double[]> generateLine(final int seed, final double expectedSlope, final double expectedIntercept,
final int trueDataCount, final int falseDataCount, final double noiseFactor) {
private List<Fittable> generateLine(final int seed, final double expectedSlope, final double expectedIntercept,
final int trueDataCount, final int falseDataCount, final double noiseFactor) {
final Random random = new Random(seed);
final PolynomialModelFitter.Model trueModel = new PolynomialModelFitter.Model(new double[]{expectedIntercept, expectedSlope});
final List<double[]> points = IntStream.range(0, trueDataCount)
.mapToObj(x -> new double[]{x, trueModel.predict(x) + random.nextGaussian() * noiseFactor})
.collect(Collectors.toList());
points.addAll(IntStream.range(0, falseDataCount).mapToObj(x -> new double[]{x * 3, random.nextDouble() * 20}).collect(Collectors.toList()));
final List<Fittable> points = IntStream.range(0, trueDataCount)
.mapToObj(x -> new SimpleFittable(new double[]{x, trueModel.predict(x) + random.nextGaussian() * noiseFactor}))
.collect(Collectors.toList());
points.addAll(IntStream.range(0, falseDataCount).mapToObj(x -> new SimpleFittable(new double[]{x * 3, random.nextDouble() * 20})).collect(Collectors.toList()));
return points;
}

private List<double[]> loadData(final String fileName) {
private List<Fittable> loadData(final String fileName) {
final InputStream inputStream = this.getClass().getResourceAsStream("/" + this.getClass().getSimpleName() + "/" + fileName);
Assertions.assertNotNull(inputStream, "Could not find resource " + fileName);
return IOUtil.readLines(inputStream)
.stream()
.map(line -> line.split(","))
.map(values -> new double[]{Double.parseDouble(values[0]), Double.parseDouble(values[1])})
.map(values -> new SimpleFittable(new double[]{Double.parseDouble(values[0]), Double.parseDouble(values[1])}))
.collect(Collectors.toList());
}

Expand All @@ -124,4 +136,4 @@ private PolynomialModelFitter.Model getBestModel(final RansacFitterOutputs<Polyn
return fitted.getBestModel().orElseThrow(() -> new RuntimeException("No model found"));
}

}
}
3 changes: 3 additions & 0 deletions src/changes/changes.xml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ If the output is not quite correct, check for invisible trailing spaces!
</properties>
<body>
<release version="4.1" date="TBD" description="TBD">
<action dev="bryan" type="update" issue="issues/462">
Replaced double[] arrays with Fittable interface in RANSAC fitting classes.
</action>
<action dev="Marthym" type="update" issue="issues/460">
Allow small numerical tolerance in SmoothStepFactory.checkBetweenZeroAndOneIncluded
</action>
Expand Down
Loading