001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Nat;
009import edu.wpi.first.math.Num;
010import edu.wpi.first.math.numbers.N1;
011import org.ejml.simple.SimpleMatrix;
012
013/**
014 * Generates sigma points and weights according to Van der Merwe's 2004 dissertation[1] for the
015 * UnscentedKalmanFilter class.
016 *
017 * <p>It parametrizes the sigma points using alpha, beta, kappa terms, and is the version seen in
018 * most publications. S3SigmaPoints is generally preferred due to its greater performance with
019 * nearly identical accuracy.
020 *
021 * <p>States is the dimensionality of the state. 2*States+1 weights will be generated.
022 *
023 * <p>[1] R. Van der Merwe "Sigma-Point Kalman Filters for Probabilistic Inference in Dynamic
024 * State-Space Models" (Doctoral dissertation)
025 *
026 * @param <S> The dimensionality of the state. 2 * States + 1 weights will be generated.
027 */
028public class MerweScaledSigmaPoints<S extends Num> implements SigmaPoints<S> {
029  private final double m_alpha;
030  private final int m_kappa;
031  private final Nat<S> m_states;
032  private Matrix<?, N1> m_wm;
033  private Matrix<?, N1> m_wc;
034
035  /**
036   * Constructs a generator for Van der Merwe scaled sigma points.
037   *
038   * @param states an instance of Num that represents the number of states.
039   * @param alpha Determines the spread of the sigma points around the mean. Usually a small
040   *     positive value (1e-3).
041   * @param beta Incorporates prior knowledge of the distribution of the mean. For Gaussian
042   *     distributions, beta = 2 is optimal.
043   * @param kappa Secondary scaling parameter usually set to 0 or 3 - States.
044   */
045  public MerweScaledSigmaPoints(Nat<S> states, double alpha, double beta, int kappa) {
046    this.m_states = states;
047    this.m_alpha = alpha;
048    this.m_kappa = kappa;
049
050    computeWeights(beta);
051  }
052
053  /**
054   * Constructs a generator for Van der Merwe scaled sigma points with default values for alpha,
055   * beta, and kappa.
056   *
057   * @param states an instance of Num that represents the number of states.
058   */
059  public MerweScaledSigmaPoints(Nat<S> states) {
060    this(states, 1e-3, 2, 3 - states.getNum());
061  }
062
063  /**
064   * Returns number of sigma points for each variable in the state x.
065   *
066   * @return The number of sigma points for each variable in the state x.
067   */
068  @Override
069  public int getNumSigmas() {
070    return 2 * m_states.getNum() + 1;
071  }
072
073  /**
074   * Computes the sigma points for an unscented Kalman filter given the mean (x) and square-root
075   * covariance (s) of the filter.
076   *
077   * @param x An array of the means.
078   * @param s Square-root covariance of the filter.
079   * @return Two-dimensional array of sigma points. Each column contains all the sigmas for one
080   *     dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}.
081   */
082  @Override
083  public Matrix<S, ?> squareRootSigmaPoints(Matrix<S, N1> x, Matrix<S, S> s) {
084    double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum();
085    double eta = Math.sqrt(lambda + m_states.getNum());
086
087    Matrix<S, S> U = s.times(eta);
088
089    // 2 * states + 1 by states
090    Matrix<S, ?> sigmas =
091        new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
092
093    // equation (17)
094    sigmas.setColumn(0, x);
095    for (int k = 0; k < m_states.getNum(); k++) {
096      var xPlusU = x.plus(U.extractColumnVector(k));
097      var xMinusU = x.minus(U.extractColumnVector(k));
098      sigmas.setColumn(k + 1, xPlusU);
099      sigmas.setColumn(m_states.getNum() + k + 1, xMinusU);
100    }
101
102    return new Matrix<>(sigmas);
103  }
104
105  /**
106   * Computes the weights for the scaled unscented Kalman filter.
107   *
108   * @param beta Incorporates prior knowledge of the distribution of the mean.
109   */
110  private void computeWeights(double beta) {
111    double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum();
112    double c = 0.5 / (m_states.getNum() + lambda);
113
114    Matrix<?, N1> wM = new Matrix<>(new SimpleMatrix(2 * m_states.getNum() + 1, 1));
115    Matrix<?, N1> wC = new Matrix<>(new SimpleMatrix(2 * m_states.getNum() + 1, 1));
116    wM.fill(c);
117    wC.fill(c);
118
119    wM.set(0, 0, lambda / (m_states.getNum() + lambda));
120    wC.set(0, 0, lambda / (m_states.getNum() + lambda) + (1 - Math.pow(m_alpha, 2) + beta));
121
122    this.m_wm = wM;
123    this.m_wc = wC;
124  }
125
126  /**
127   * Returns the weight for each sigma point for the mean.
128   *
129   * @return the weight for each sigma point for the mean.
130   */
131  @Override
132  public Matrix<?, N1> getWm() {
133    return m_wm;
134  }
135
136  /**
137   * Returns an element of the weight for each sigma point for the mean.
138   *
139   * @param element Element of vector to return.
140   * @return the element i's weight for the mean.
141   */
142  @Override
143  public double getWm(int element) {
144    return m_wm.get(element, 0);
145  }
146
147  /**
148   * Returns the weight for each sigma point for the covariance.
149   *
150   * @return the weight for each sigma point for the covariance.
151   */
152  @Override
153  public Matrix<?, N1> getWc() {
154    return m_wc;
155  }
156
157  /**
158   * Returns an element of the weight for each sigma point for the covariance.
159   *
160   * @param element Element of vector to return.
161   * @return The element I's weight for the covariance.
162   */
163  @Override
164  public double getWc(int element) {
165    return m_wc.get(element, 0);
166  }
167}