001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Nat;
009import edu.wpi.first.math.Num;
010import edu.wpi.first.math.numbers.N1;
011import org.ejml.simple.SimpleMatrix;
012
013/**
014 * Generates sigma points and weights according to Van der Merwe's 2004 dissertation[1] for the
015 * UnscentedKalmanFilter class.
016 *
017 * <p>It parametrizes the sigma points using alpha, beta, kappa terms, and is the version seen in
018 * most publications. Unless you know better, this should be your default choice.
019 *
020 * <p>States is the dimensionality of the state. 2*States+1 weights will be generated.
021 *
022 * <p>[1] R. Van der Merwe "Sigma-Point Kalman Filters for Probabilitic Inference in Dynamic
023 * State-Space Models" (Doctoral dissertation)
024 *
025 * @param <S> The dimensionality of the state. 2 * States + 1 weights will be generated.
026 */
027public class MerweScaledSigmaPoints<S extends Num> {
028  private final double m_alpha;
029  private final int m_kappa;
030  private final Nat<S> m_states;
031  private Matrix<?, N1> m_wm;
032  private Matrix<?, N1> m_wc;
033
034  /**
035   * Constructs a generator for Van der Merwe scaled sigma points.
036   *
037   * @param states an instance of Num that represents the number of states.
038   * @param alpha Determines the spread of the sigma points around the mean. Usually a small
039   *     positive value (1e-3).
040   * @param beta Incorporates prior knowledge of the distribution of the mean. For Gaussian
041   *     distributions, beta = 2 is optimal.
042   * @param kappa Secondary scaling parameter usually set to 0 or 3 - States.
043   */
044  public MerweScaledSigmaPoints(Nat<S> states, double alpha, double beta, int kappa) {
045    this.m_states = states;
046    this.m_alpha = alpha;
047    this.m_kappa = kappa;
048
049    computeWeights(beta);
050  }
051
052  /**
053   * Constructs a generator for Van der Merwe scaled sigma points with default values for alpha,
054   * beta, and kappa.
055   *
056   * @param states an instance of Num that represents the number of states.
057   */
058  public MerweScaledSigmaPoints(Nat<S> states) {
059    this(states, 1e-3, 2, 3 - states.getNum());
060  }
061
062  /**
063   * Returns number of sigma points for each variable in the state x.
064   *
065   * @return The number of sigma points for each variable in the state x.
066   */
067  public int getNumSigmas() {
068    return 2 * m_states.getNum() + 1;
069  }
070
071  /**
072   * Computes the sigma points for an unscented Kalman filter given the mean (x) and covariance(P)
073   * of the filter.
074   *
075   * @param x An array of the means.
076   * @param s Square-root covariance of the filter.
077   * @return Two-dimensional array of sigma points. Each column contains all the sigmas for one
078   *     dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}.
079   */
080  public Matrix<S, ?> squareRootSigmaPoints(Matrix<S, N1> x, Matrix<S, S> s) {
081    double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum();
082    double eta = Math.sqrt(lambda + m_states.getNum());
083
084    Matrix<S, S> U = s.times(eta);
085
086    // 2 * states + 1 by states
087    Matrix<S, ?> sigmas =
088        new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
089    sigmas.setColumn(0, x);
090    for (int k = 0; k < m_states.getNum(); k++) {
091      var xPlusU = x.plus(U.extractColumnVector(k));
092      var xMinusU = x.minus(U.extractColumnVector(k));
093      sigmas.setColumn(k + 1, xPlusU);
094      sigmas.setColumn(m_states.getNum() + k + 1, xMinusU);
095    }
096
097    return new Matrix<>(sigmas);
098  }
099
100  /**
101   * Computes the weights for the scaled unscented Kalman filter.
102   *
103   * @param beta Incorporates prior knowledge of the distribution of the mean.
104   */
105  private void computeWeights(double beta) {
106    double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum();
107    double c = 0.5 / (m_states.getNum() + lambda);
108
109    Matrix<?, N1> wM = new Matrix<>(new SimpleMatrix(2 * m_states.getNum() + 1, 1));
110    Matrix<?, N1> wC = new Matrix<>(new SimpleMatrix(2 * m_states.getNum() + 1, 1));
111    wM.fill(c);
112    wC.fill(c);
113
114    wM.set(0, 0, lambda / (m_states.getNum() + lambda));
115    wC.set(0, 0, lambda / (m_states.getNum() + lambda) + (1 - Math.pow(m_alpha, 2) + beta));
116
117    this.m_wm = wM;
118    this.m_wc = wC;
119  }
120
121  /**
122   * Returns the weight for each sigma point for the mean.
123   *
124   * @return the weight for each sigma point for the mean.
125   */
126  public Matrix<?, N1> getWm() {
127    return m_wm;
128  }
129
130  /**
131   * Returns an element of the weight for each sigma point for the mean.
132   *
133   * @param element Element of vector to return.
134   * @return the element i's weight for the mean.
135   */
136  public double getWm(int element) {
137    return m_wm.get(element, 0);
138  }
139
140  /**
141   * Returns the weight for each sigma point for the covariance.
142   *
143   * @return the weight for each sigma point for the covariance.
144   */
145  public Matrix<?, N1> getWc() {
146    return m_wc;
147  }
148
149  /**
150   * Returns an element of the weight for each sigma point for the covariance.
151   *
152   * @param element Element of vector to return.
153   * @return The element I's weight for the covariance.
154   */
155  public double getWc(int element) {
156    return m_wc.get(element, 0);
157  }
158}