001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Nat;
009import edu.wpi.first.math.Num;
010import edu.wpi.first.math.Pair;
011import edu.wpi.first.math.StateSpaceUtil;
012import edu.wpi.first.math.numbers.N1;
013import edu.wpi.first.math.system.Discretization;
014import edu.wpi.first.math.system.NumericalIntegration;
015import edu.wpi.first.math.system.NumericalJacobian;
016import java.util.function.BiFunction;
017import org.ejml.dense.row.decomposition.qr.QRDecompositionHouseholder_DDRM;
018import org.ejml.simple.SimpleMatrix;
019
020/**
021 * A Kalman filter combines predictions from a model and measurements to give an estimate of the
022 * true system state. This is useful because many states cannot be measured directly as a result of
023 * sensor noise, or because the state is "hidden".
024 *
025 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
026 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
027 * of squares error in the state estimate. This K gain is used to correct the state estimate by some
028 * amount of the difference between the actual measurements and the measurements predicted by the
029 * model.
030 *
031 * <p>An unscented Kalman filter uses nonlinear state and measurement models. It propagates the
032 * error covariance using sigma points chosen to approximate the true probability distribution.
033 *
034 * <p>For more on the underlying math, read <a
035 * href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">https://file.tavsys.net/control/controls-engineering-in-frc.pdf</a>
036 * chapter 9 "Stochastic control theory".
037 *
038 * <p>This class implements a square-root-form unscented Kalman filter (SR-UKF). For more
039 * information about the SR-UKF, see <a
040 * href="https://www.researchgate.net/publication/3908304">https://www.researchgate.net/publication/3908304</a>.
041 *
042 * @param <States> Number of states.
043 * @param <Inputs> Number of inputs.
044 * @param <Outputs> Number of outputs.
045 */
046public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
047    implements KalmanTypeFilter<States, Inputs, Outputs> {
048  private final Nat<States> m_states;
049  private final Nat<Outputs> m_outputs;
050
051  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
052  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h;
053
054  private BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> m_meanFuncX;
055  private BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> m_meanFuncY;
056  private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_residualFuncX;
057  private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY;
058  private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX;
059
060  private Matrix<States, N1> m_xHat;
061  private Matrix<States, States> m_S;
062  private final Matrix<States, States> m_contQ;
063  private final Matrix<Outputs, Outputs> m_contR;
064  private Matrix<States, ?> m_sigmasF;
065  private double m_dtSeconds;
066
067  private final MerweScaledSigmaPoints<States> m_pts;
068
069  /**
070   * Constructs an Unscented Kalman Filter.
071   *
072   * <p>See <a
073   * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
074   * for how to select the standard deviations.
075   *
076   * @param states A Nat representing the number of states.
077   * @param outputs A Nat representing the number of outputs.
078   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
079   * @param h A vector-valued function of x and u that returns the measurement vector.
080   * @param stateStdDevs Standard deviations of model states.
081   * @param measurementStdDevs Standard deviations of measurements.
082   * @param nominalDtSeconds Nominal discretization timestep.
083   */
084  public UnscentedKalmanFilter(
085      Nat<States> states,
086      Nat<Outputs> outputs,
087      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
088      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
089      Matrix<States, N1> stateStdDevs,
090      Matrix<Outputs, N1> measurementStdDevs,
091      double nominalDtSeconds) {
092    this(
093        states,
094        outputs,
095        f,
096        h,
097        stateStdDevs,
098        measurementStdDevs,
099        (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
100        (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
101        Matrix::minus,
102        Matrix::minus,
103        Matrix::plus,
104        nominalDtSeconds);
105  }
106
107  /**
108   * Constructs an unscented Kalman filter with custom mean, residual, and addition functions. Using
109   * custom functions for arithmetic can be useful if you have angles in the state or measurements,
110   * because they allow you to correctly account for the modular nature of angle arithmetic.
111   *
112   * <p>See <a
113   * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
114   * for how to select the standard deviations.
115   *
116   * @param states A Nat representing the number of states.
117   * @param outputs A Nat representing the number of outputs.
118   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
119   * @param h A vector-valued function of x and u that returns the measurement vector.
120   * @param stateStdDevs Standard deviations of model states.
121   * @param measurementStdDevs Standard deviations of measurements.
122   * @param meanFuncX A function that computes the mean of 2 * States + 1 state vectors using a
123   *     given set of weights.
124   * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using
125   *     a given set of weights.
126   * @param residualFuncX A function that computes the residual of two state vectors (i.e. it
127   *     subtracts them.)
128   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
129   *     subtracts them.)
130   * @param addFuncX A function that adds two state vectors.
131   * @param nominalDtSeconds Nominal discretization timestep.
132   */
133  public UnscentedKalmanFilter(
134      Nat<States> states,
135      Nat<Outputs> outputs,
136      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
137      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
138      Matrix<States, N1> stateStdDevs,
139      Matrix<Outputs, N1> measurementStdDevs,
140      BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> meanFuncX,
141      BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> meanFuncY,
142      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
143      BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY,
144      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX,
145      double nominalDtSeconds) {
146    this.m_states = states;
147    this.m_outputs = outputs;
148
149    m_f = f;
150    m_h = h;
151
152    m_meanFuncX = meanFuncX;
153    m_meanFuncY = meanFuncY;
154    m_residualFuncX = residualFuncX;
155    m_residualFuncY = residualFuncY;
156    m_addFuncX = addFuncX;
157
158    m_dtSeconds = nominalDtSeconds;
159
160    m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
161    m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
162
163    m_pts = new MerweScaledSigmaPoints<>(states);
164
165    reset();
166  }
167
168  static <S extends Num, C extends Num>
169      Pair<Matrix<C, N1>, Matrix<C, C>> squareRootUnscentedTransform(
170          Nat<S> s,
171          Nat<C> dim,
172          Matrix<C, ?> sigmas,
173          Matrix<?, N1> Wm,
174          Matrix<?, N1> Wc,
175          BiFunction<Matrix<C, ?>, Matrix<?, N1>, Matrix<C, N1>> meanFunc,
176          BiFunction<Matrix<C, N1>, Matrix<C, N1>, Matrix<C, N1>> residualFunc,
177          Matrix<C, C> squareRootR) {
178    if (sigmas.getNumRows() != dim.getNum() || sigmas.getNumCols() != 2 * s.getNum() + 1) {
179      throw new IllegalArgumentException(
180          "Sigmas must be covDim by 2 * states + 1! Got "
181              + sigmas.getNumRows()
182              + " by "
183              + sigmas.getNumCols());
184    }
185
186    if (Wm.getNumRows() != 2 * s.getNum() + 1 || Wm.getNumCols() != 1) {
187      throw new IllegalArgumentException(
188          "Wm must be 2 * states + 1 by 1! Got " + Wm.getNumRows() + " by " + Wm.getNumCols());
189    }
190
191    if (Wc.getNumRows() != 2 * s.getNum() + 1 || Wc.getNumCols() != 1) {
192      throw new IllegalArgumentException(
193          "Wc must be 2 * states + 1 by 1! Got " + Wc.getNumRows() + " by " + Wc.getNumCols());
194    }
195
196    // New mean is usually just the sum of the sigmas * weight:
197    //       n
198    // dot = Σ W[k] Xᵢ[k]
199    //      k=1
200    Matrix<C, N1> x = meanFunc.apply(sigmas, Wm);
201
202    Matrix<C, ?> Sbar = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + dim.getNum()));
203    for (int i = 0; i < 2 * s.getNum(); i++) {
204      Sbar.setColumn(
205          i,
206          residualFunc.apply(sigmas.extractColumnVector(1 + i), x).times(Math.sqrt(Wc.get(1, 0))));
207    }
208    Sbar.assignBlock(0, 2 * s.getNum(), squareRootR);
209
210    QRDecompositionHouseholder_DDRM qr = new QRDecompositionHouseholder_DDRM();
211    var qrStorage = Sbar.transpose().getStorage();
212
213    if (!qr.decompose(qrStorage.getDDRM())) {
214      throw new RuntimeException("QR decomposition failed! Input matrix:\n" + qrStorage);
215    }
216
217    Matrix<C, C> newS = new Matrix<>(new SimpleMatrix(qr.getR(null, true)));
218    newS.rankUpdate(residualFunc.apply(sigmas.extractColumnVector(0), x), Wc.get(0, 0), false);
219
220    return new Pair<>(x, newS);
221  }
222
223  /**
224   * Returns the square-root error covariance matrix S.
225   *
226   * @return the square-root error covariance matrix S.
227   */
228  public Matrix<States, States> getS() {
229    return m_S;
230  }
231
232  /**
233   * Returns an element of the square-root error covariance matrix S.
234   *
235   * @param row Row of S.
236   * @param col Column of S.
237   * @return the value of the square-root error covariance matrix S at (i, j).
238   */
239  public double getS(int row, int col) {
240    return m_S.get(row, col);
241  }
242
243  /**
244   * Sets the entire square-root error covariance matrix S.
245   *
246   * @param newS The new value of S to use.
247   */
248  public void setS(Matrix<States, States> newS) {
249    m_S = newS;
250  }
251
252  /**
253   * Returns the reconstructed error covariance matrix P.
254   *
255   * @return the error covariance matrix P.
256   */
257  @Override
258  public Matrix<States, States> getP() {
259    return m_S.transpose().times(m_S);
260  }
261
262  /**
263   * Returns an element of the error covariance matrix P.
264   *
265   * @param row Row of P.
266   * @param col Column of P.
267   * @return the value of the error covariance matrix P at (i, j).
268   * @throws UnsupportedOperationException indexing into the reconstructed P matrix is not supported
269   */
270  @Override
271  public double getP(int row, int col) {
272    throw new UnsupportedOperationException(
273        "indexing into the reconstructed P matrix is not supported");
274  }
275
276  /**
277   * Sets the entire error covariance matrix P.
278   *
279   * @param newP The new value of P to use.
280   */
281  @Override
282  public void setP(Matrix<States, States> newP) {
283    m_S = newP.lltDecompose(false);
284  }
285
286  /**
287   * Returns the state estimate x-hat.
288   *
289   * @return the state estimate x-hat.
290   */
291  @Override
292  public Matrix<States, N1> getXhat() {
293    return m_xHat;
294  }
295
296  /**
297   * Returns an element of the state estimate x-hat.
298   *
299   * @param row Row of x-hat.
300   * @return the value of the state estimate x-hat at 'i'.
301   */
302  @Override
303  public double getXhat(int row) {
304    return m_xHat.get(row, 0);
305  }
306
307  /**
308   * Set initial state estimate x-hat.
309   *
310   * @param xHat The state estimate x-hat.
311   */
312  @Override
313  public void setXhat(Matrix<States, N1> xHat) {
314    m_xHat = xHat;
315  }
316
317  /**
318   * Set an element of the initial state estimate x-hat.
319   *
320   * @param row Row of x-hat.
321   * @param value Value for element of x-hat.
322   */
323  @Override
324  public void setXhat(int row, double value) {
325    m_xHat.set(row, 0, value);
326  }
327
328  /** Resets the observer. */
329  @Override
330  public final void reset() {
331    m_xHat = new Matrix<>(m_states, Nat.N1());
332    m_S = new Matrix<>(m_states, m_states);
333    m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
334  }
335
336  /**
337   * Project the model into the future with a new control input u.
338   *
339   * @param u New control input from controller.
340   * @param dtSeconds Timestep for prediction.
341   */
342  @Override
343  public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
344    // Discretize Q before projecting mean and covariance forward
345    Matrix<States, States> contA =
346        NumericalJacobian.numericalJacobianX(m_states, m_states, m_f, m_xHat, u);
347    var discQ = Discretization.discretizeAQ(contA, m_contQ, dtSeconds).getSecond();
348    var squareRootDiscQ = discQ.lltDecompose(true);
349
350    var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S);
351
352    for (int i = 0; i < m_pts.getNumSigmas(); ++i) {
353      Matrix<States, N1> x = sigmas.extractColumnVector(i);
354
355      m_sigmasF.setColumn(i, NumericalIntegration.rk4(m_f, x, u, dtSeconds));
356    }
357
358    var ret =
359        squareRootUnscentedTransform(
360            m_states,
361            m_states,
362            m_sigmasF,
363            m_pts.getWm(),
364            m_pts.getWc(),
365            m_meanFuncX,
366            m_residualFuncX,
367            squareRootDiscQ);
368
369    m_xHat = ret.getFirst();
370    m_S = ret.getSecond();
371    m_dtSeconds = dtSeconds;
372  }
373
374  /**
375   * Correct the state estimate x-hat using the measurements in y.
376   *
377   * @param u Same control input used in the predict step.
378   * @param y Measurement vector.
379   */
380  @Override
381  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
382    correct(
383        m_outputs, u, y, m_h, m_contR, m_meanFuncY, m_residualFuncY, m_residualFuncX, m_addFuncX);
384  }
385
386  /**
387   * Correct the state estimate x-hat using the measurements in y.
388   *
389   * <p>This is useful for when the measurement noise covariances vary.
390   *
391   * @param u Same control input used in the predict step.
392   * @param y Measurement vector.
393   * @param R Continuous measurement noise covariance matrix.
394   */
395  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y, Matrix<Outputs, Outputs> R) {
396    correct(m_outputs, u, y, m_h, R, m_meanFuncY, m_residualFuncY, m_residualFuncX, m_addFuncX);
397  }
398
399  /**
400   * Correct the state estimate x-hat using the measurements in y.
401   *
402   * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
403   * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
404   * of this function).
405   *
406   * @param <R> Number of measurements in y.
407   * @param rows Number of rows in y.
408   * @param u Same control input used in the predict step.
409   * @param y Measurement vector.
410   * @param h A vector-valued function of x and u that returns the measurement vector.
411   * @param R Continuous measurement noise covariance matrix.
412   */
413  public <R extends Num> void correct(
414      Nat<R> rows,
415      Matrix<Inputs, N1> u,
416      Matrix<R, N1> y,
417      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h,
418      Matrix<R, R> R) {
419    BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY =
420        (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm));
421    BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX =
422        Matrix::minus;
423    BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY = Matrix::minus;
424    BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX = Matrix::plus;
425    correct(rows, u, y, h, R, meanFuncY, residualFuncY, residualFuncX, addFuncX);
426  }
427
428  /**
429   * Correct the state estimate x-hat using the measurements in y.
430   *
431   * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
432   * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
433   * of this function).
434   *
435   * @param <R> Number of measurements in y.
436   * @param rows Number of rows in y.
437   * @param u Same control input used in the predict step.
438   * @param y Measurement vector.
439   * @param h A vector-valued function of x and u that returns the measurement vector.
440   * @param R Continuous measurement noise covariance matrix.
441   * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using
442   *     a given set of weights.
443   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
444   *     subtracts them.)
445   * @param residualFuncX A function that computes the residual of two state vectors (i.e. it
446   *     subtracts them.)
447   * @param addFuncX A function that adds two state vectors.
448   */
449  public <R extends Num> void correct(
450      Nat<R> rows,
451      Matrix<Inputs, N1> u,
452      Matrix<R, N1> y,
453      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h,
454      Matrix<R, R> R,
455      BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY,
456      BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY,
457      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
458      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) {
459    final var discR = Discretization.discretizeR(R, m_dtSeconds);
460    final var squareRootDiscR = discR.lltDecompose(true);
461
462    // Transform sigma points into measurement space
463    Matrix<R, ?> sigmasH = new Matrix<>(new SimpleMatrix(rows.getNum(), 2 * m_states.getNum() + 1));
464    var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S);
465    for (int i = 0; i < m_pts.getNumSigmas(); i++) {
466      Matrix<R, N1> hRet = h.apply(sigmas.extractColumnVector(i), u);
467      sigmasH.setColumn(i, hRet);
468    }
469
470    // Mean and covariance of prediction passed through unscented transform
471    var transRet =
472        squareRootUnscentedTransform(
473            m_states,
474            rows,
475            sigmasH,
476            m_pts.getWm(),
477            m_pts.getWc(),
478            meanFuncY,
479            residualFuncY,
480            squareRootDiscR);
481    var yHat = transRet.getFirst();
482    var Sy = transRet.getSecond();
483
484    // Compute cross covariance of the state and the measurements
485    Matrix<States, R> Pxy = new Matrix<>(m_states, rows);
486    for (int i = 0; i < m_pts.getNumSigmas(); i++) {
487      // Pxy += (sigmas_f[:, i] - x̂)(sigmas_h[:, i] - ŷ)ᵀ W_c[i]
488      var dx = residualFuncX.apply(m_sigmasF.extractColumnVector(i), m_xHat);
489      var dy = residualFuncY.apply(sigmasH.extractColumnVector(i), yHat).transpose();
490
491      Pxy = Pxy.plus(dx.times(dy).times(m_pts.getWc(i)));
492    }
493
494    // K = (P_{xy} / S_yᵀ) / S_y
495    // K = (S_y \ P_{xy}ᵀ)ᵀ / S_y
496    // K = (S_yᵀ \ (S_y \ P_{xy}ᵀ))ᵀ
497    Matrix<States, R> K =
498        Sy.transpose()
499            .solveFullPivHouseholderQr(Sy.solveFullPivHouseholderQr(Pxy.transpose()))
500            .transpose();
501
502    // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − ŷ)
503    m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, yHat)));
504
505    Matrix<States, R> U = K.times(Sy);
506    for (int i = 0; i < rows.getNum(); i++) {
507      m_S.rankUpdate(U.extractColumnVector(i), -1, false);
508    }
509  }
510}