001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.DARE;
008import edu.wpi.first.math.Matrix;
009import edu.wpi.first.math.Nat;
010import edu.wpi.first.math.Num;
011import edu.wpi.first.math.StateSpaceUtil;
012import edu.wpi.first.math.numbers.N1;
013import edu.wpi.first.math.system.Discretization;
014import edu.wpi.first.math.system.NumericalIntegration;
015import edu.wpi.first.math.system.NumericalJacobian;
016import java.util.function.BiFunction;
017
018/**
019 * A Kalman filter combines predictions from a model and measurements to give an estimate of the
020 * true system state. This is useful because many states cannot be measured directly as a result of
021 * sensor noise, or because the state is "hidden".
022 *
023 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
024 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
025 * of squares error in the state estimate. This K gain is used to correct the state estimate by some
026 * amount of the difference between the actual measurements and the measurements predicted by the
027 * model.
028 *
029 * <p>An extended Kalman filter supports nonlinear state and measurement models. It propagates the
030 * error covariance by linearizing the models around the state estimate, then applying the linear
031 * Kalman filter equations.
032 *
033 * <p>For more on the underlying math, read <a
034 * href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">https://file.tavsys.net/control/controls-engineering-in-frc.pdf</a>
035 * chapter 9 "Stochastic control theory".
036 *
037 * @param <States> Number of states.
038 * @param <Inputs> Number of inputs.
039 * @param <Outputs> Number of outputs.
040 */
041public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
042    implements KalmanTypeFilter<States, Inputs, Outputs> {
043  private final Nat<States> m_states;
044  private final Nat<Outputs> m_outputs;
045
046  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
047
048  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h;
049
050  private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY;
051  private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX;
052
053  private final Matrix<States, States> m_contQ;
054  private final Matrix<States, States> m_initP;
055  private final Matrix<Outputs, Outputs> m_contR;
056
057  private Matrix<States, N1> m_xHat;
058
059  private Matrix<States, States> m_P;
060
061  private double m_dtSeconds;
062
063  /**
064   * Constructs an extended Kalman filter.
065   *
066   * <p>See <a
067   * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
068   * for how to select the standard deviations.
069   *
070   * @param states a Nat representing the number of states.
071   * @param inputs a Nat representing the number of inputs.
072   * @param outputs a Nat representing the number of outputs.
073   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
074   * @param h A vector-valued function of x and u that returns the measurement vector.
075   * @param stateStdDevs Standard deviations of model states.
076   * @param measurementStdDevs Standard deviations of measurements.
077   * @param dtSeconds Nominal discretization timestep.
078   */
079  public ExtendedKalmanFilter(
080      Nat<States> states,
081      Nat<Inputs> inputs,
082      Nat<Outputs> outputs,
083      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
084      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
085      Matrix<States, N1> stateStdDevs,
086      Matrix<Outputs, N1> measurementStdDevs,
087      double dtSeconds) {
088    this(
089        states,
090        inputs,
091        outputs,
092        f,
093        h,
094        stateStdDevs,
095        measurementStdDevs,
096        Matrix::minus,
097        Matrix::plus,
098        dtSeconds);
099  }
100
101  /**
102   * Constructs an extended Kalman filter.
103   *
104   * <p>See <a
105   * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
106   * for how to select the standard deviations.
107   *
108   * @param states a Nat representing the number of states.
109   * @param inputs a Nat representing the number of inputs.
110   * @param outputs a Nat representing the number of outputs.
111   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
112   * @param h A vector-valued function of x and u that returns the measurement vector.
113   * @param stateStdDevs Standard deviations of model states.
114   * @param measurementStdDevs Standard deviations of measurements.
115   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
116   *     subtracts them.)
117   * @param addFuncX A function that adds two state vectors.
118   * @param dtSeconds Nominal discretization timestep.
119   */
120  public ExtendedKalmanFilter(
121      Nat<States> states,
122      Nat<Inputs> inputs,
123      Nat<Outputs> outputs,
124      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
125      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
126      Matrix<States, N1> stateStdDevs,
127      Matrix<Outputs, N1> measurementStdDevs,
128      BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY,
129      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX,
130      double dtSeconds) {
131    m_states = states;
132    m_outputs = outputs;
133
134    m_f = f;
135    m_h = h;
136
137    m_residualFuncY = residualFuncY;
138    m_addFuncX = addFuncX;
139
140    m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
141    m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
142    m_dtSeconds = dtSeconds;
143
144    reset();
145
146    final var contA =
147        NumericalJacobian.numericalJacobianX(
148            states, states, f, m_xHat, new Matrix<>(inputs, Nat.N1()));
149    final var C =
150        NumericalJacobian.numericalJacobianX(
151            outputs, states, h, m_xHat, new Matrix<>(inputs, Nat.N1()));
152
153    final var discPair = Discretization.discretizeAQ(contA, m_contQ, dtSeconds);
154    final var discA = discPair.getFirst();
155    final var discQ = discPair.getSecond();
156
157    final var discR = Discretization.discretizeR(m_contR, dtSeconds);
158
159    if (StateSpaceUtil.isDetectable(discA, C) && outputs.getNum() <= states.getNum()) {
160      m_initP = DARE.dare(discA.transpose(), C.transpose(), discQ, discR);
161    } else {
162      m_initP = new Matrix<>(states, states);
163    }
164
165    m_P = m_initP;
166  }
167
168  /**
169   * Returns the error covariance matrix P.
170   *
171   * @return the error covariance matrix P.
172   */
173  @Override
174  public Matrix<States, States> getP() {
175    return m_P;
176  }
177
178  /**
179   * Returns an element of the error covariance matrix P.
180   *
181   * @param row Row of P.
182   * @param col Column of P.
183   * @return the value of the error covariance matrix P at (i, j).
184   */
185  @Override
186  public double getP(int row, int col) {
187    return m_P.get(row, col);
188  }
189
190  /**
191   * Sets the entire error covariance matrix P.
192   *
193   * @param newP The new value of P to use.
194   */
195  @Override
196  public void setP(Matrix<States, States> newP) {
197    m_P = newP;
198  }
199
200  /**
201   * Returns the state estimate x-hat.
202   *
203   * @return the state estimate x-hat.
204   */
205  @Override
206  public Matrix<States, N1> getXhat() {
207    return m_xHat;
208  }
209
210  /**
211   * Returns an element of the state estimate x-hat.
212   *
213   * @param row Row of x-hat.
214   * @return the value of the state estimate x-hat at that row.
215   */
216  @Override
217  public double getXhat(int row) {
218    return m_xHat.get(row, 0);
219  }
220
221  /**
222   * Set initial state estimate x-hat.
223   *
224   * @param xHat The state estimate x-hat.
225   */
226  @Override
227  public void setXhat(Matrix<States, N1> xHat) {
228    m_xHat = xHat;
229  }
230
231  /**
232   * Set an element of the initial state estimate x-hat.
233   *
234   * @param row Row of x-hat.
235   * @param value Value for element of x-hat.
236   */
237  @Override
238  public void setXhat(int row, double value) {
239    m_xHat.set(row, 0, value);
240  }
241
242  @Override
243  public final void reset() {
244    m_xHat = new Matrix<>(m_states, Nat.N1());
245    m_P = m_initP;
246  }
247
248  /**
249   * Project the model into the future with a new control input u.
250   *
251   * @param u New control input from controller.
252   * @param dtSeconds Timestep for prediction.
253   */
254  @Override
255  public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
256    predict(u, m_f, dtSeconds);
257  }
258
259  /**
260   * Project the model into the future with a new control input u.
261   *
262   * @param u New control input from controller.
263   * @param f The function used to linearize the model.
264   * @param dtSeconds Timestep for prediction.
265   */
266  public void predict(
267      Matrix<Inputs, N1> u,
268      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
269      double dtSeconds) {
270    // Find continuous A
271    final var contA = NumericalJacobian.numericalJacobianX(m_states, m_states, f, m_xHat, u);
272
273    // Find discrete A and Q
274    final var discPair = Discretization.discretizeAQ(contA, m_contQ, dtSeconds);
275    final var discA = discPair.getFirst();
276    final var discQ = discPair.getSecond();
277
278    m_xHat = NumericalIntegration.rk4(f, m_xHat, u, dtSeconds);
279
280    // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q
281    m_P = discA.times(m_P).times(discA.transpose()).plus(discQ);
282
283    m_dtSeconds = dtSeconds;
284  }
285
286  /**
287   * Correct the state estimate x-hat using the measurements in y.
288   *
289   * @param u Same control input used in the predict step.
290   * @param y Measurement vector.
291   */
292  @Override
293  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
294    correct(m_outputs, u, y, m_h, m_contR, m_residualFuncY, m_addFuncX);
295  }
296
297  /**
298   * Correct the state estimate x-hat using the measurements in y.
299   *
300   * <p>This is useful for when the measurement noise covariances vary.
301   *
302   * @param u Same control input used in the predict step.
303   * @param y Measurement vector.
304   * @param R Continuous measurement noise covariance matrix.
305   */
306  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y, Matrix<Outputs, Outputs> R) {
307    correct(m_outputs, u, y, m_h, R, m_residualFuncY, m_addFuncX);
308  }
309
310  /**
311   * Correct the state estimate x-hat using the measurements in y.
312   *
313   * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
314   * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
315   * of this function).
316   *
317   * @param <Rows> Number of rows in the result of f(x, u).
318   * @param rows Number of rows in the result of f(x, u).
319   * @param u Same control input used in the predict step.
320   * @param y Measurement vector.
321   * @param h A vector-valued function of x and u that returns the measurement vector.
322   * @param R Continuous measurement noise covariance matrix.
323   */
324  public <Rows extends Num> void correct(
325      Nat<Rows> rows,
326      Matrix<Inputs, N1> u,
327      Matrix<Rows, N1> y,
328      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h,
329      Matrix<Rows, Rows> R) {
330    correct(rows, u, y, h, R, Matrix::minus, Matrix::plus);
331  }
332
333  /**
334   * Correct the state estimate x-hat using the measurements in y.
335   *
336   * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
337   * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
338   * of this function).
339   *
340   * @param <Rows> Number of rows in the result of f(x, u).
341   * @param rows Number of rows in the result of f(x, u).
342   * @param u Same control input used in the predict step.
343   * @param y Measurement vector.
344   * @param h A vector-valued function of x and u that returns the measurement vector.
345   * @param R Continuous measurement noise covariance matrix.
346   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
347   *     subtracts them.)
348   * @param addFuncX A function that adds two state vectors.
349   */
350  public <Rows extends Num> void correct(
351      Nat<Rows> rows,
352      Matrix<Inputs, N1> u,
353      Matrix<Rows, N1> y,
354      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h,
355      Matrix<Rows, Rows> R,
356      BiFunction<Matrix<Rows, N1>, Matrix<Rows, N1>, Matrix<Rows, N1>> residualFuncY,
357      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) {
358    final var C = NumericalJacobian.numericalJacobianX(rows, m_states, h, m_xHat, u);
359    final var discR = Discretization.discretizeR(R, m_dtSeconds);
360
361    final var S = C.times(m_P).times(C.transpose()).plus(discR);
362
363    // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more
364    // efficiently.
365    //
366    // K = PCᵀS⁻¹
367    // KS = PCᵀ
368    // (KS)ᵀ = (PCᵀ)ᵀ
369    // SᵀKᵀ = CPᵀ
370    //
371    // The solution of Ax = b can be found via x = A.solve(b).
372    //
373    // Kᵀ = Sᵀ.solve(CPᵀ)
374    // K = (Sᵀ.solve(CPᵀ))ᵀ
375    final Matrix<States, Rows> K = S.transpose().solve(C.times(m_P.transpose())).transpose();
376
377    // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − h(x̂ₖ₊₁⁻, uₖ₊₁))
378    m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, h.apply(m_xHat, u))));
379
380    // Pₖ₊₁⁺ = (I−Kₖ₊₁C)Pₖ₊₁⁻(I−Kₖ₊₁C)ᵀ + Kₖ₊₁RKₖ₊₁ᵀ
381    // Use Joseph form for numerical stability
382    m_P =
383        Matrix.eye(m_states)
384            .minus(K.times(C))
385            .times(m_P)
386            .times(Matrix.eye(m_states).minus(K.times(C)).transpose())
387            .plus(K.times(discR).times(K.transpose()));
388  }
389}