001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Nat;
009import edu.wpi.first.math.Num;
010import edu.wpi.first.math.numbers.N1;
011import org.ejml.simple.SimpleMatrix;
012
013/**
014 * Generates sigma points and weights according to Van der Merwe's 2004 dissertation[1] for the
015 * UnscentedKalmanFilter class.
016 *
017 * <p>It parametrizes the sigma points using alpha, beta, kappa terms, and is the version seen in
018 * most publications. Unless you know better, this should be your default choice.
019 *
020 * <p>States is the dimensionality of the state. 2*States+1 weights will be generated.
021 *
022 * <p>[1] R. Van der Merwe "Sigma-Point Kalman Filters for Probabilitic Inference in Dynamic
023 * State-Space Models" (Doctoral dissertation)
024 */
025public class MerweScaledSigmaPoints<S extends Num> {
026  private final double m_alpha;
027  private final int m_kappa;
028  private final Nat<S> m_states;
029  private Matrix<?, N1> m_wm;
030  private Matrix<?, N1> m_wc;
031
032  /**
033   * Constructs a generator for Van der Merwe scaled sigma points.
034   *
035   * @param states an instance of Num that represents the number of states.
036   * @param alpha Determines the spread of the sigma points around the mean. Usually a small
037   *     positive value (1e-3).
038   * @param beta Incorporates prior knowledge of the distribution of the mean. For Gaussian
039   *     distributions, beta = 2 is optimal.
040   * @param kappa Secondary scaling parameter usually set to 0 or 3 - States.
041   */
042  public MerweScaledSigmaPoints(Nat<S> states, double alpha, double beta, int kappa) {
043    this.m_states = states;
044    this.m_alpha = alpha;
045    this.m_kappa = kappa;
046
047    computeWeights(beta);
048  }
049
050  /**
051   * Constructs a generator for Van der Merwe scaled sigma points with default values for alpha,
052   * beta, and kappa.
053   *
054   * @param states an instance of Num that represents the number of states.
055   */
056  public MerweScaledSigmaPoints(Nat<S> states) {
057    this(states, 1e-3, 2, 3 - states.getNum());
058  }
059
060  /**
061   * Returns number of sigma points for each variable in the state x.
062   *
063   * @return The number of sigma points for each variable in the state x.
064   */
065  public int getNumSigmas() {
066    return 2 * m_states.getNum() + 1;
067  }
068
069  /**
070   * Computes the sigma points for an unscented Kalman filter given the mean (x) and covariance(P)
071   * of the filter.
072   *
073   * @param x An array of the means.
074   * @param s Square-root covariance of the filter.
075   * @return Two-dimensional array of sigma points. Each column contains all the sigmas for one
076   *     dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}.
077   */
078  public Matrix<S, ?> squareRootSigmaPoints(Matrix<S, N1> x, Matrix<S, S> s) {
079    double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum();
080    double eta = Math.sqrt(lambda + m_states.getNum());
081
082    Matrix<S, S> U = s.times(eta);
083
084    // 2 * states + 1 by states
085    Matrix<S, ?> sigmas =
086        new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
087    sigmas.setColumn(0, x);
088    for (int k = 0; k < m_states.getNum(); k++) {
089      var xPlusU = x.plus(U.extractColumnVector(k));
090      var xMinusU = x.minus(U.extractColumnVector(k));
091      sigmas.setColumn(k + 1, xPlusU);
092      sigmas.setColumn(m_states.getNum() + k + 1, xMinusU);
093    }
094
095    return new Matrix<>(sigmas);
096  }
097
098  /**
099   * Computes the weights for the scaled unscented Kalman filter.
100   *
101   * @param beta Incorporates prior knowledge of the distribution of the mean.
102   */
103  private void computeWeights(double beta) {
104    double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum();
105    double c = 0.5 / (m_states.getNum() + lambda);
106
107    Matrix<?, N1> wM = new Matrix<>(new SimpleMatrix(2 * m_states.getNum() + 1, 1));
108    Matrix<?, N1> wC = new Matrix<>(new SimpleMatrix(2 * m_states.getNum() + 1, 1));
109    wM.fill(c);
110    wC.fill(c);
111
112    wM.set(0, 0, lambda / (m_states.getNum() + lambda));
113    wC.set(0, 0, lambda / (m_states.getNum() + lambda) + (1 - Math.pow(m_alpha, 2) + beta));
114
115    this.m_wm = wM;
116    this.m_wc = wC;
117  }
118
119  /**
120   * Returns the weight for each sigma point for the mean.
121   *
122   * @return the weight for each sigma point for the mean.
123   */
124  public Matrix<?, N1> getWm() {
125    return m_wm;
126  }
127
128  /**
129   * Returns an element of the weight for each sigma point for the mean.
130   *
131   * @param element Element of vector to return.
132   * @return the element i's weight for the mean.
133   */
134  public double getWm(int element) {
135    return m_wm.get(element, 0);
136  }
137
138  /**
139   * Returns the weight for each sigma point for the covariance.
140   *
141   * @return the weight for each sigma point for the covariance.
142   */
143  public Matrix<?, N1> getWc() {
144    return m_wc;
145  }
146
147  /**
148   * Returns an element of the weight for each sigma point for the covariance.
149   *
150   * @param element Element of vector to return.
151   * @return The element I's weight for the covariance.
152   */
153  public double getWc(int element) {
154    return m_wc.get(element, 0);
155  }
156}