001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Nat;
009import edu.wpi.first.math.Num;
010import edu.wpi.first.math.numbers.N1;
011import java.util.function.BiFunction;
012
013/**
014 * An Unscented Kalman Filter using sigma points and weights from Van der Merwe's 2004 dissertation.
015 * S3UKF is generally preferred due to its greater performance with nearly identical accuracy.
016 *
017 * @param <States> Number of states.
018 * @param <Inputs> Number of inputs.
019 * @param <Outputs> Number of outputs.
020 */
021public class MerweUKF<States extends Num, Inputs extends Num, Outputs extends Num>
022    extends UnscentedKalmanFilter<States, Inputs, Outputs> {
023  /**
024   * Constructs a Merwe Unscented Kalman Filter.
025   *
026   * <p>See <a
027   * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
028   * for how to select the standard deviations.
029   *
030   * @param states A Nat representing the number of states.
031   * @param outputs A Nat representing the number of outputs.
032   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
033   * @param h A vector-valued function of x and u that returns the measurement vector.
034   * @param stateStdDevs Standard deviations of model states.
035   * @param measurementStdDevs Standard deviations of measurements.
036   * @param nominalDt Nominal discretization timestep in seconds.
037   */
038  public MerweUKF(
039      Nat<States> states,
040      Nat<Outputs> outputs,
041      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
042      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
043      Matrix<States, N1> stateStdDevs,
044      Matrix<Outputs, N1> measurementStdDevs,
045      double nominalDt) {
046    super(
047        new MerweScaledSigmaPoints<>(states),
048        states,
049        outputs,
050        f,
051        h,
052        stateStdDevs,
053        measurementStdDevs,
054        (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
055        (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
056        Matrix::minus,
057        Matrix::minus,
058        Matrix::plus,
059        nominalDt);
060  }
061
062  /**
063   * Constructs a Merwe Unscented Kalman filter with custom mean, residual, and addition functions.
064   * Using custom functions for arithmetic can be useful if you have angles in the state or
065   * measurements, because they allow you to correctly account for the modular nature of angle
066   * arithmetic.
067   *
068   * <p>See <a
069   * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
070   * for how to select the standard deviations.
071   *
072   * @param states A Nat representing the number of states.
073   * @param outputs A Nat representing the number of outputs.
074   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
075   * @param h A vector-valued function of x and u that returns the measurement vector.
076   * @param stateStdDevs Standard deviations of model states.
077   * @param measurementStdDevs Standard deviations of measurements.
078   * @param meanFuncX A function that computes the mean of NumSigmas state vectors using a given set
079   *     of weights.
080   * @param meanFuncY A function that computes the mean of NumSigmas measurement vectors using a
081   *     given set of weights.
082   * @param residualFuncX A function that computes the residual of two state vectors (i.e. it
083   *     subtracts them.)
084   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
085   *     subtracts them.)
086   * @param addFuncX A function that adds two state vectors.
087   * @param nominalDt Nominal discretization timestep in seconds.
088   */
089  public MerweUKF(
090      Nat<States> states,
091      Nat<Outputs> outputs,
092      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
093      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
094      Matrix<States, N1> stateStdDevs,
095      Matrix<Outputs, N1> measurementStdDevs,
096      BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> meanFuncX,
097      BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> meanFuncY,
098      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
099      BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY,
100      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX,
101      double nominalDt) {
102    super(
103        new MerweScaledSigmaPoints<>(states),
104        states,
105        outputs,
106        f,
107        h,
108        stateStdDevs,
109        measurementStdDevs,
110        meanFuncX,
111        meanFuncY,
112        residualFuncX,
113        residualFuncY,
114        addFuncX,
115        nominalDt);
116  }
117}