001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Nat;
009import edu.wpi.first.math.Num;
010import edu.wpi.first.math.Pair;
011import edu.wpi.first.math.StateSpaceUtil;
012import edu.wpi.first.math.numbers.N1;
013import edu.wpi.first.math.system.Discretization;
014import edu.wpi.first.math.system.NumericalIntegration;
015import edu.wpi.first.math.system.NumericalJacobian;
016import java.util.function.BiFunction;
017import org.ejml.dense.row.decomposition.qr.QRDecompositionHouseholder_DDRM;
018import org.ejml.simple.SimpleMatrix;
019
020/**
021 * A Kalman filter combines predictions from a model and measurements to give an estimate of the
022 * true system state. This is useful because many states cannot be measured directly as a result of
023 * sensor noise, or because the state is "hidden".
024 *
025 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
026 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
027 * of squares error in the state estimate. This K gain is used to correct the state estimate by some
028 * amount of the difference between the actual measurements and the measurements predicted by the
029 * model.
030 *
031 * <p>An unscented Kalman filter uses nonlinear state and measurement models. It propagates the
032 * error covariance using sigma points chosen to approximate the true probability distribution.
033 *
034 * <p>For more on the underlying math, read <a
035 * href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">https://file.tavsys.net/control/controls-engineering-in-frc.pdf</a>
036 * chapter 9 "Stochastic control theory".
037 *
038 * <p>This class implements a square-root-form unscented Kalman filter (SR-UKF). For more
039 * information about the SR-UKF, see <a
040 * href="https://www.researchgate.net/publication/3908304">https://www.researchgate.net/publication/3908304</a>.
041 */
042public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
043    implements KalmanTypeFilter<States, Inputs, Outputs> {
044  private final Nat<States> m_states;
045  private final Nat<Outputs> m_outputs;
046
047  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
048  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h;
049
050  private BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> m_meanFuncX;
051  private BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> m_meanFuncY;
052  private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_residualFuncX;
053  private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY;
054  private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX;
055
056  private Matrix<States, N1> m_xHat;
057  private Matrix<States, States> m_S;
058  private final Matrix<States, States> m_contQ;
059  private final Matrix<Outputs, Outputs> m_contR;
060  private Matrix<States, ?> m_sigmasF;
061  private double m_dtSeconds;
062
063  private final MerweScaledSigmaPoints<States> m_pts;
064
065  /**
066   * Constructs an Unscented Kalman Filter.
067   *
068   * <p>See <a
069   * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
070   * for how to select the standard deviations.
071   *
072   * @param states A Nat representing the number of states.
073   * @param outputs A Nat representing the number of outputs.
074   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
075   * @param h A vector-valued function of x and u that returns the measurement vector.
076   * @param stateStdDevs Standard deviations of model states.
077   * @param measurementStdDevs Standard deviations of measurements.
078   * @param nominalDtSeconds Nominal discretization timestep.
079   */
080  public UnscentedKalmanFilter(
081      Nat<States> states,
082      Nat<Outputs> outputs,
083      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
084      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
085      Matrix<States, N1> stateStdDevs,
086      Matrix<Outputs, N1> measurementStdDevs,
087      double nominalDtSeconds) {
088    this(
089        states,
090        outputs,
091        f,
092        h,
093        stateStdDevs,
094        measurementStdDevs,
095        (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
096        (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
097        Matrix::minus,
098        Matrix::minus,
099        Matrix::plus,
100        nominalDtSeconds);
101  }
102
103  /**
104   * Constructs an unscented Kalman filter with custom mean, residual, and addition functions. Using
105   * custom functions for arithmetic can be useful if you have angles in the state or measurements,
106   * because they allow you to correctly account for the modular nature of angle arithmetic.
107   *
108   * <p>See <a
109   * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
110   * for how to select the standard deviations.
111   *
112   * @param states A Nat representing the number of states.
113   * @param outputs A Nat representing the number of outputs.
114   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
115   * @param h A vector-valued function of x and u that returns the measurement vector.
116   * @param stateStdDevs Standard deviations of model states.
117   * @param measurementStdDevs Standard deviations of measurements.
118   * @param meanFuncX A function that computes the mean of 2 * States + 1 state vectors using a
119   *     given set of weights.
120   * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using
121   *     a given set of weights.
122   * @param residualFuncX A function that computes the residual of two state vectors (i.e. it
123   *     subtracts them.)
124   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
125   *     subtracts them.)
126   * @param addFuncX A function that adds two state vectors.
127   * @param nominalDtSeconds Nominal discretization timestep.
128   */
129  public UnscentedKalmanFilter(
130      Nat<States> states,
131      Nat<Outputs> outputs,
132      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
133      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
134      Matrix<States, N1> stateStdDevs,
135      Matrix<Outputs, N1> measurementStdDevs,
136      BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> meanFuncX,
137      BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> meanFuncY,
138      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
139      BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY,
140      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX,
141      double nominalDtSeconds) {
142    this.m_states = states;
143    this.m_outputs = outputs;
144
145    m_f = f;
146    m_h = h;
147
148    m_meanFuncX = meanFuncX;
149    m_meanFuncY = meanFuncY;
150    m_residualFuncX = residualFuncX;
151    m_residualFuncY = residualFuncY;
152    m_addFuncX = addFuncX;
153
154    m_dtSeconds = nominalDtSeconds;
155
156    m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
157    m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
158
159    m_pts = new MerweScaledSigmaPoints<>(states);
160
161    reset();
162  }
163
164  static <S extends Num, C extends Num>
165      Pair<Matrix<C, N1>, Matrix<C, C>> squareRootUnscentedTransform(
166          Nat<S> s,
167          Nat<C> dim,
168          Matrix<C, ?> sigmas,
169          Matrix<?, N1> Wm,
170          Matrix<?, N1> Wc,
171          BiFunction<Matrix<C, ?>, Matrix<?, N1>, Matrix<C, N1>> meanFunc,
172          BiFunction<Matrix<C, N1>, Matrix<C, N1>, Matrix<C, N1>> residualFunc,
173          Matrix<C, C> squareRootR) {
174    if (sigmas.getNumRows() != dim.getNum() || sigmas.getNumCols() != 2 * s.getNum() + 1) {
175      throw new IllegalArgumentException(
176          "Sigmas must be covDim by 2 * states + 1! Got "
177              + sigmas.getNumRows()
178              + " by "
179              + sigmas.getNumCols());
180    }
181
182    if (Wm.getNumRows() != 2 * s.getNum() + 1 || Wm.getNumCols() != 1) {
183      throw new IllegalArgumentException(
184          "Wm must be 2 * states + 1 by 1! Got " + Wm.getNumRows() + " by " + Wm.getNumCols());
185    }
186
187    if (Wc.getNumRows() != 2 * s.getNum() + 1 || Wc.getNumCols() != 1) {
188      throw new IllegalArgumentException(
189          "Wc must be 2 * states + 1 by 1! Got " + Wc.getNumRows() + " by " + Wc.getNumCols());
190    }
191
192    // New mean is usually just the sum of the sigmas * weight:
193    //       n
194    // dot = Σ W[k] Xᵢ[k]
195    //      k=1
196    Matrix<C, N1> x = meanFunc.apply(sigmas, Wm);
197
198    Matrix<C, ?> Sbar = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + dim.getNum()));
199    for (int i = 0; i < 2 * s.getNum(); i++) {
200      Sbar.setColumn(
201          i,
202          residualFunc.apply(sigmas.extractColumnVector(1 + i), x).times(Math.sqrt(Wc.get(1, 0))));
203    }
204    Sbar.assignBlock(0, 2 * s.getNum(), squareRootR);
205
206    QRDecompositionHouseholder_DDRM qr = new QRDecompositionHouseholder_DDRM();
207    var qrStorage = Sbar.transpose().getStorage();
208
209    if (!qr.decompose(qrStorage.getDDRM())) {
210      throw new RuntimeException("QR decomposition failed! Input matrix:\n" + qrStorage);
211    }
212
213    Matrix<C, C> newS = new Matrix<>(new SimpleMatrix(qr.getR(null, true)));
214    newS.rankUpdate(residualFunc.apply(sigmas.extractColumnVector(0), x), Wc.get(0, 0), false);
215
216    return new Pair<>(x, newS);
217  }
218
219  /**
220   * Returns the square-root error covariance matrix S.
221   *
222   * @return the square-root error covariance matrix S.
223   */
224  public Matrix<States, States> getS() {
225    return m_S;
226  }
227
228  /**
229   * Returns an element of the square-root error covariance matrix S.
230   *
231   * @param row Row of S.
232   * @param col Column of S.
233   * @return the value of the square-root error covariance matrix S at (i, j).
234   */
235  public double getS(int row, int col) {
236    return m_S.get(row, col);
237  }
238
239  /**
240   * Sets the entire square-root error covariance matrix S.
241   *
242   * @param newS The new value of S to use.
243   */
244  public void setS(Matrix<States, States> newS) {
245    m_S = newS;
246  }
247
248  /**
249   * Returns the reconstructed error covariance matrix P.
250   *
251   * @return the error covariance matrix P.
252   */
253  @Override
254  public Matrix<States, States> getP() {
255    return m_S.transpose().times(m_S);
256  }
257
258  /**
259   * Returns an element of the error covariance matrix P.
260   *
261   * @param row Row of P.
262   * @param col Column of P.
263   * @return the value of the error covariance matrix P at (i, j).
264   * @throws UnsupportedOperationException indexing into the reconstructed P matrix is not supported
265   */
266  @Override
267  public double getP(int row, int col) {
268    throw new UnsupportedOperationException(
269        "indexing into the reconstructed P matrix is not supported");
270  }
271
272  /**
273   * Sets the entire error covariance matrix P.
274   *
275   * @param newP The new value of P to use.
276   */
277  @Override
278  public void setP(Matrix<States, States> newP) {
279    m_S = newP.lltDecompose(false);
280  }
281
282  /**
283   * Returns the state estimate x-hat.
284   *
285   * @return the state estimate x-hat.
286   */
287  @Override
288  public Matrix<States, N1> getXhat() {
289    return m_xHat;
290  }
291
292  /**
293   * Returns an element of the state estimate x-hat.
294   *
295   * @param row Row of x-hat.
296   * @return the value of the state estimate x-hat at 'i'.
297   */
298  @Override
299  public double getXhat(int row) {
300    return m_xHat.get(row, 0);
301  }
302
303  /**
304   * Set initial state estimate x-hat.
305   *
306   * @param xHat The state estimate x-hat.
307   */
308  @Override
309  public void setXhat(Matrix<States, N1> xHat) {
310    m_xHat = xHat;
311  }
312
313  /**
314   * Set an element of the initial state estimate x-hat.
315   *
316   * @param row Row of x-hat.
317   * @param value Value for element of x-hat.
318   */
319  @Override
320  public void setXhat(int row, double value) {
321    m_xHat.set(row, 0, value);
322  }
323
324  /** Resets the observer. */
325  @Override
326  public final void reset() {
327    m_xHat = new Matrix<>(m_states, Nat.N1());
328    m_S = new Matrix<>(m_states, m_states);
329    m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
330  }
331
332  /**
333   * Project the model into the future with a new control input u.
334   *
335   * @param u New control input from controller.
336   * @param dtSeconds Timestep for prediction.
337   */
338  @Override
339  public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
340    // Discretize Q before projecting mean and covariance forward
341    Matrix<States, States> contA =
342        NumericalJacobian.numericalJacobianX(m_states, m_states, m_f, m_xHat, u);
343    var discQ = Discretization.discretizeAQ(contA, m_contQ, dtSeconds).getSecond();
344    var squareRootDiscQ = discQ.lltDecompose(true);
345
346    var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S);
347
348    for (int i = 0; i < m_pts.getNumSigmas(); ++i) {
349      Matrix<States, N1> x = sigmas.extractColumnVector(i);
350
351      m_sigmasF.setColumn(i, NumericalIntegration.rk4(m_f, x, u, dtSeconds));
352    }
353
354    var ret =
355        squareRootUnscentedTransform(
356            m_states,
357            m_states,
358            m_sigmasF,
359            m_pts.getWm(),
360            m_pts.getWc(),
361            m_meanFuncX,
362            m_residualFuncX,
363            squareRootDiscQ);
364
365    m_xHat = ret.getFirst();
366    m_S = ret.getSecond();
367    m_dtSeconds = dtSeconds;
368  }
369
370  /**
371   * Correct the state estimate x-hat using the measurements in y.
372   *
373   * @param u Same control input used in the predict step.
374   * @param y Measurement vector.
375   */
376  @Override
377  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
378    correct(
379        m_outputs, u, y, m_h, m_contR, m_meanFuncY, m_residualFuncY, m_residualFuncX, m_addFuncX);
380  }
381
382  /**
383   * Correct the state estimate x-hat using the measurements in y.
384   *
385   * <p>This is useful for when the measurement noise covariances vary.
386   *
387   * @param u Same control input used in the predict step.
388   * @param y Measurement vector.
389   * @param R Continuous measurement noise covariance matrix.
390   */
391  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y, Matrix<Outputs, Outputs> R) {
392    correct(m_outputs, u, y, m_h, R, m_meanFuncY, m_residualFuncY, m_residualFuncX, m_addFuncX);
393  }
394
395  /**
396   * Correct the state estimate x-hat using the measurements in y.
397   *
398   * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
399   * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
400   * of this function).
401   *
402   * @param <R> Number of measurements in y.
403   * @param rows Number of rows in y.
404   * @param u Same control input used in the predict step.
405   * @param y Measurement vector.
406   * @param h A vector-valued function of x and u that returns the measurement vector.
407   * @param R Continuous measurement noise covariance matrix.
408   */
409  public <R extends Num> void correct(
410      Nat<R> rows,
411      Matrix<Inputs, N1> u,
412      Matrix<R, N1> y,
413      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h,
414      Matrix<R, R> R) {
415    BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY =
416        (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm));
417    BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX =
418        Matrix::minus;
419    BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY = Matrix::minus;
420    BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX = Matrix::plus;
421    correct(rows, u, y, h, R, meanFuncY, residualFuncY, residualFuncX, addFuncX);
422  }
423
424  /**
425   * Correct the state estimate x-hat using the measurements in y.
426   *
427   * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
428   * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
429   * of this function).
430   *
431   * @param <R> Number of measurements in y.
432   * @param rows Number of rows in y.
433   * @param u Same control input used in the predict step.
434   * @param y Measurement vector.
435   * @param h A vector-valued function of x and u that returns the measurement vector.
436   * @param R Continuous measurement noise covariance matrix.
437   * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using
438   *     a given set of weights.
439   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
440   *     subtracts them.)
441   * @param residualFuncX A function that computes the residual of two state vectors (i.e. it
442   *     subtracts them.)
443   * @param addFuncX A function that adds two state vectors.
444   */
445  public <R extends Num> void correct(
446      Nat<R> rows,
447      Matrix<Inputs, N1> u,
448      Matrix<R, N1> y,
449      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h,
450      Matrix<R, R> R,
451      BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY,
452      BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY,
453      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
454      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) {
455    final var discR = Discretization.discretizeR(R, m_dtSeconds);
456    final var squareRootDiscR = discR.lltDecompose(true);
457
458    // Transform sigma points into measurement space
459    Matrix<R, ?> sigmasH = new Matrix<>(new SimpleMatrix(rows.getNum(), 2 * m_states.getNum() + 1));
460    var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S);
461    for (int i = 0; i < m_pts.getNumSigmas(); i++) {
462      Matrix<R, N1> hRet = h.apply(sigmas.extractColumnVector(i), u);
463      sigmasH.setColumn(i, hRet);
464    }
465
466    // Mean and covariance of prediction passed through unscented transform
467    var transRet =
468        squareRootUnscentedTransform(
469            m_states,
470            rows,
471            sigmasH,
472            m_pts.getWm(),
473            m_pts.getWc(),
474            meanFuncY,
475            residualFuncY,
476            squareRootDiscR);
477    var yHat = transRet.getFirst();
478    var Sy = transRet.getSecond();
479
480    // Compute cross covariance of the state and the measurements
481    Matrix<States, R> Pxy = new Matrix<>(m_states, rows);
482    for (int i = 0; i < m_pts.getNumSigmas(); i++) {
483      // Pxy += (sigmas_f[:, i] - x̂)(sigmas_h[:, i] - ŷ)ᵀ W_c[i]
484      var dx = residualFuncX.apply(m_sigmasF.extractColumnVector(i), m_xHat);
485      var dy = residualFuncY.apply(sigmasH.extractColumnVector(i), yHat).transpose();
486
487      Pxy = Pxy.plus(dx.times(dy).times(m_pts.getWc(i)));
488    }
489
490    // K = (P_{xy} / S_yᵀ) / S_y
491    // K = (S_y \ P_{xy}ᵀ)ᵀ / S_y
492    // K = (S_yᵀ \ (S_y \ P_{xy}ᵀ))ᵀ
493    Matrix<States, R> K =
494        Sy.transpose()
495            .solveFullPivHouseholderQr(Sy.solveFullPivHouseholderQr(Pxy.transpose()))
496            .transpose();
497
498    // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − ŷ)
499    m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, yHat)));
500
501    Matrix<States, R> U = K.times(Sy);
502    for (int i = 0; i < rows.getNum(); i++) {
503      m_S.rankUpdate(U.extractColumnVector(i), -1, false);
504    }
505  }
506}