001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.DARE;
008import edu.wpi.first.math.Matrix;
009import edu.wpi.first.math.Nat;
010import edu.wpi.first.math.Num;
011import edu.wpi.first.math.StateSpaceUtil;
012import edu.wpi.first.math.numbers.N1;
013import edu.wpi.first.math.system.Discretization;
014import edu.wpi.first.math.system.NumericalIntegration;
015import edu.wpi.first.math.system.NumericalJacobian;
016import java.util.function.BiFunction;
017
018/**
019 * A Kalman filter combines predictions from a model and measurements to give an estimate of the
020 * true system state. This is useful because many states cannot be measured directly as a result of
021 * sensor noise, or because the state is "hidden".
022 *
023 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
024 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
025 * of squares error in the state estimate. This K gain is used to correct the state estimate by some
026 * amount of the difference between the actual measurements and the measurements predicted by the
027 * model.
028 *
029 * <p>An extended Kalman filter supports nonlinear state and measurement models. It propagates the
030 * error covariance by linearizing the models around the state estimate, then applying the linear
031 * Kalman filter equations.
032 *
033 * <p>For more on the underlying math, read <a
034 * href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">https://file.tavsys.net/control/controls-engineering-in-frc.pdf</a>
035 * chapter 9 "Stochastic control theory".
036 */
037public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
038    implements KalmanTypeFilter<States, Inputs, Outputs> {
039  private final Nat<States> m_states;
040  private final Nat<Outputs> m_outputs;
041
042  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
043
044  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h;
045
046  private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY;
047  private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX;
048
049  private final Matrix<States, States> m_contQ;
050  private final Matrix<States, States> m_initP;
051  private final Matrix<Outputs, Outputs> m_contR;
052
053  private Matrix<States, N1> m_xHat;
054
055  private Matrix<States, States> m_P;
056
057  private double m_dtSeconds;
058
059  /**
060   * Constructs an extended Kalman filter.
061   *
062   * <p>See <a
063   * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
064   * for how to select the standard deviations.
065   *
066   * @param states a Nat representing the number of states.
067   * @param inputs a Nat representing the number of inputs.
068   * @param outputs a Nat representing the number of outputs.
069   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
070   * @param h A vector-valued function of x and u that returns the measurement vector.
071   * @param stateStdDevs Standard deviations of model states.
072   * @param measurementStdDevs Standard deviations of measurements.
073   * @param dtSeconds Nominal discretization timestep.
074   */
075  public ExtendedKalmanFilter(
076      Nat<States> states,
077      Nat<Inputs> inputs,
078      Nat<Outputs> outputs,
079      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
080      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
081      Matrix<States, N1> stateStdDevs,
082      Matrix<Outputs, N1> measurementStdDevs,
083      double dtSeconds) {
084    this(
085        states,
086        inputs,
087        outputs,
088        f,
089        h,
090        stateStdDevs,
091        measurementStdDevs,
092        Matrix::minus,
093        Matrix::plus,
094        dtSeconds);
095  }
096
097  /**
098   * Constructs an extended Kalman filter.
099   *
100   * <p>See <a
101   * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
102   * for how to select the standard deviations.
103   *
104   * @param states a Nat representing the number of states.
105   * @param inputs a Nat representing the number of inputs.
106   * @param outputs a Nat representing the number of outputs.
107   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
108   * @param h A vector-valued function of x and u that returns the measurement vector.
109   * @param stateStdDevs Standard deviations of model states.
110   * @param measurementStdDevs Standard deviations of measurements.
111   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
112   *     subtracts them.)
113   * @param addFuncX A function that adds two state vectors.
114   * @param dtSeconds Nominal discretization timestep.
115   */
116  public ExtendedKalmanFilter(
117      Nat<States> states,
118      Nat<Inputs> inputs,
119      Nat<Outputs> outputs,
120      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
121      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
122      Matrix<States, N1> stateStdDevs,
123      Matrix<Outputs, N1> measurementStdDevs,
124      BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY,
125      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX,
126      double dtSeconds) {
127    m_states = states;
128    m_outputs = outputs;
129
130    m_f = f;
131    m_h = h;
132
133    m_residualFuncY = residualFuncY;
134    m_addFuncX = addFuncX;
135
136    m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
137    m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
138    m_dtSeconds = dtSeconds;
139
140    reset();
141
142    final var contA =
143        NumericalJacobian.numericalJacobianX(
144            states, states, f, m_xHat, new Matrix<>(inputs, Nat.N1()));
145    final var C =
146        NumericalJacobian.numericalJacobianX(
147            outputs, states, h, m_xHat, new Matrix<>(inputs, Nat.N1()));
148
149    final var discPair = Discretization.discretizeAQ(contA, m_contQ, dtSeconds);
150    final var discA = discPair.getFirst();
151    final var discQ = discPair.getSecond();
152
153    final var discR = Discretization.discretizeR(m_contR, dtSeconds);
154
155    if (StateSpaceUtil.isDetectable(discA, C) && outputs.getNum() <= states.getNum()) {
156      m_initP = DARE.dare(discA.transpose(), C.transpose(), discQ, discR);
157    } else {
158      m_initP = new Matrix<>(states, states);
159    }
160
161    m_P = m_initP;
162  }
163
164  /**
165   * Returns the error covariance matrix P.
166   *
167   * @return the error covariance matrix P.
168   */
169  @Override
170  public Matrix<States, States> getP() {
171    return m_P;
172  }
173
174  /**
175   * Returns an element of the error covariance matrix P.
176   *
177   * @param row Row of P.
178   * @param col Column of P.
179   * @return the value of the error covariance matrix P at (i, j).
180   */
181  @Override
182  public double getP(int row, int col) {
183    return m_P.get(row, col);
184  }
185
186  /**
187   * Sets the entire error covariance matrix P.
188   *
189   * @param newP The new value of P to use.
190   */
191  @Override
192  public void setP(Matrix<States, States> newP) {
193    m_P = newP;
194  }
195
196  /**
197   * Returns the state estimate x-hat.
198   *
199   * @return the state estimate x-hat.
200   */
201  @Override
202  public Matrix<States, N1> getXhat() {
203    return m_xHat;
204  }
205
206  /**
207   * Returns an element of the state estimate x-hat.
208   *
209   * @param row Row of x-hat.
210   * @return the value of the state estimate x-hat at that row.
211   */
212  @Override
213  public double getXhat(int row) {
214    return m_xHat.get(row, 0);
215  }
216
217  /**
218   * Set initial state estimate x-hat.
219   *
220   * @param xHat The state estimate x-hat.
221   */
222  @Override
223  public void setXhat(Matrix<States, N1> xHat) {
224    m_xHat = xHat;
225  }
226
227  /**
228   * Set an element of the initial state estimate x-hat.
229   *
230   * @param row Row of x-hat.
231   * @param value Value for element of x-hat.
232   */
233  @Override
234  public void setXhat(int row, double value) {
235    m_xHat.set(row, 0, value);
236  }
237
238  @Override
239  public final void reset() {
240    m_xHat = new Matrix<>(m_states, Nat.N1());
241    m_P = m_initP;
242  }
243
244  /**
245   * Project the model into the future with a new control input u.
246   *
247   * @param u New control input from controller.
248   * @param dtSeconds Timestep for prediction.
249   */
250  @Override
251  public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
252    predict(u, m_f, dtSeconds);
253  }
254
255  /**
256   * Project the model into the future with a new control input u.
257   *
258   * @param u New control input from controller.
259   * @param f The function used to linearize the model.
260   * @param dtSeconds Timestep for prediction.
261   */
262  public void predict(
263      Matrix<Inputs, N1> u,
264      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
265      double dtSeconds) {
266    // Find continuous A
267    final var contA = NumericalJacobian.numericalJacobianX(m_states, m_states, f, m_xHat, u);
268
269    // Find discrete A and Q
270    final var discPair = Discretization.discretizeAQ(contA, m_contQ, dtSeconds);
271    final var discA = discPair.getFirst();
272    final var discQ = discPair.getSecond();
273
274    m_xHat = NumericalIntegration.rk4(f, m_xHat, u, dtSeconds);
275
276    // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q
277    m_P = discA.times(m_P).times(discA.transpose()).plus(discQ);
278
279    m_dtSeconds = dtSeconds;
280  }
281
282  /**
283   * Correct the state estimate x-hat using the measurements in y.
284   *
285   * @param u Same control input used in the predict step.
286   * @param y Measurement vector.
287   */
288  @Override
289  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
290    correct(m_outputs, u, y, m_h, m_contR, m_residualFuncY, m_addFuncX);
291  }
292
293  /**
294   * Correct the state estimate x-hat using the measurements in y.
295   *
296   * <p>This is useful for when the measurement noise covariances vary.
297   *
298   * @param u Same control input used in the predict step.
299   * @param y Measurement vector.
300   * @param R Continuous measurement noise covariance matrix.
301   */
302  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y, Matrix<Outputs, Outputs> R) {
303    correct(m_outputs, u, y, m_h, R, m_residualFuncY, m_addFuncX);
304  }
305
306  /**
307   * Correct the state estimate x-hat using the measurements in y.
308   *
309   * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
310   * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
311   * of this function).
312   *
313   * @param <Rows> Number of rows in the result of f(x, u).
314   * @param rows Number of rows in the result of f(x, u).
315   * @param u Same control input used in the predict step.
316   * @param y Measurement vector.
317   * @param h A vector-valued function of x and u that returns the measurement vector.
318   * @param R Continuous measurement noise covariance matrix.
319   */
320  public <Rows extends Num> void correct(
321      Nat<Rows> rows,
322      Matrix<Inputs, N1> u,
323      Matrix<Rows, N1> y,
324      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h,
325      Matrix<Rows, Rows> R) {
326    correct(rows, u, y, h, R, Matrix::minus, Matrix::plus);
327  }
328
329  /**
330   * Correct the state estimate x-hat using the measurements in y.
331   *
332   * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
333   * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
334   * of this function).
335   *
336   * @param <Rows> Number of rows in the result of f(x, u).
337   * @param rows Number of rows in the result of f(x, u).
338   * @param u Same control input used in the predict step.
339   * @param y Measurement vector.
340   * @param h A vector-valued function of x and u that returns the measurement vector.
341   * @param R Continuous measurement noise covariance matrix.
342   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
343   *     subtracts them.)
344   * @param addFuncX A function that adds two state vectors.
345   */
346  public <Rows extends Num> void correct(
347      Nat<Rows> rows,
348      Matrix<Inputs, N1> u,
349      Matrix<Rows, N1> y,
350      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h,
351      Matrix<Rows, Rows> R,
352      BiFunction<Matrix<Rows, N1>, Matrix<Rows, N1>, Matrix<Rows, N1>> residualFuncY,
353      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) {
354    final var C = NumericalJacobian.numericalJacobianX(rows, m_states, h, m_xHat, u);
355    final var discR = Discretization.discretizeR(R, m_dtSeconds);
356
357    final var S = C.times(m_P).times(C.transpose()).plus(discR);
358
359    // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more
360    // efficiently.
361    //
362    // K = PCᵀS⁻¹
363    // KS = PCᵀ
364    // (KS)ᵀ = (PCᵀ)ᵀ
365    // SᵀKᵀ = CPᵀ
366    //
367    // The solution of Ax = b can be found via x = A.solve(b).
368    //
369    // Kᵀ = Sᵀ.solve(CPᵀ)
370    // K = (Sᵀ.solve(CPᵀ))ᵀ
371    final Matrix<States, Rows> K = S.transpose().solve(C.times(m_P.transpose())).transpose();
372
373    // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − h(x̂ₖ₊₁⁻, uₖ₊₁))
374    m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, h.apply(m_xHat, u))));
375
376    // Pₖ₊₁⁺ = (I−Kₖ₊₁C)Pₖ₊₁⁻(I−Kₖ₊₁C)ᵀ + Kₖ₊₁RKₖ₊₁ᵀ
377    // Use Joseph form for numerical stability
378    m_P =
379        Matrix.eye(m_states)
380            .minus(K.times(C))
381            .times(m_P)
382            .times(Matrix.eye(m_states).minus(K.times(C)).transpose())
383            .plus(K.times(discR).times(K.transpose()));
384  }
385}