001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.DARE;
008import edu.wpi.first.math.MathSharedStore;
009import edu.wpi.first.math.Matrix;
010import edu.wpi.first.math.Nat;
011import edu.wpi.first.math.Num;
012import edu.wpi.first.math.StateSpaceUtil;
013import edu.wpi.first.math.numbers.N1;
014import edu.wpi.first.math.system.Discretization;
015import edu.wpi.first.math.system.LinearSystem;
016
017/**
018 * A Kalman filter combines predictions from a model and measurements to give an estimate of the
019 * true system state. This is useful because many states cannot be measured directly as a result of
020 * sensor noise, or because the state is "hidden".
021 *
022 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
023 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
024 * of squares error in the state estimate. This K gain is used to correct the state estimate by some
025 * amount of the difference between the actual measurements and the measurements predicted by the
026 * model.
027 *
028 * <p>For more on the underlying math, read <a
029 * href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">https://file.tavsys.net/control/controls-engineering-in-frc.pdf</a>
030 * chapter 9 "Stochastic control theory".
031 */
032public class KalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
033    implements KalmanTypeFilter<States, Inputs, Outputs> {
034  private final Nat<States> m_states;
035
036  private final LinearSystem<States, Inputs, Outputs> m_plant;
037  private Matrix<States, N1> m_xHat;
038  private Matrix<States, States> m_P;
039  private final Matrix<States, States> m_contQ;
040  private final Matrix<Outputs, Outputs> m_contR;
041  private double m_dtSeconds;
042
043  private final Matrix<States, States> m_initP;
044
045  /**
046   * Constructs a Kalman filter with the given plant.
047   *
048   * <p>See <a
049   * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
050   * for how to select the standard deviations.
051   *
052   * @param states A Nat representing the states of the system.
053   * @param outputs A Nat representing the outputs of the system.
054   * @param plant The plant used for the prediction step.
055   * @param stateStdDevs Standard deviations of model states.
056   * @param measurementStdDevs Standard deviations of measurements.
057   * @param dtSeconds Nominal discretization timestep.
058   * @throws IllegalArgumentException If the system is unobservable.
059   */
060  public KalmanFilter(
061      Nat<States> states,
062      Nat<Outputs> outputs,
063      LinearSystem<States, Inputs, Outputs> plant,
064      Matrix<States, N1> stateStdDevs,
065      Matrix<Outputs, N1> measurementStdDevs,
066      double dtSeconds) {
067    this.m_states = states;
068
069    this.m_plant = plant;
070
071    m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
072    m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
073    m_dtSeconds = dtSeconds;
074
075    // Find discrete A and Q
076    var pair = Discretization.discretizeAQ(plant.getA(), m_contQ, dtSeconds);
077    var discA = pair.getFirst();
078    var discQ = pair.getSecond();
079
080    var discR = Discretization.discretizeR(m_contR, dtSeconds);
081
082    var C = plant.getC();
083
084    if (!StateSpaceUtil.isDetectable(discA, C)) {
085      var msg =
086          "The system passed to the Kalman filter is unobservable!\n\nA =\n"
087              + discA.getStorage().toString()
088              + "\nC =\n"
089              + C.getStorage().toString()
090              + '\n';
091      MathSharedStore.reportError(msg, Thread.currentThread().getStackTrace());
092      throw new IllegalArgumentException(msg);
093    }
094
095    m_initP = new Matrix<>(DARE.dare(discA.transpose(), C.transpose(), discQ, discR));
096
097    reset();
098  }
099
100  /**
101   * Returns the error covariance matrix P.
102   *
103   * @return the error covariance matrix P.
104   */
105  @Override
106  public Matrix<States, States> getP() {
107    return m_P;
108  }
109
110  /**
111   * Returns an element of the error covariance matrix P.
112   *
113   * @param row Row of P.
114   * @param col Column of P.
115   * @return the value of the error covariance matrix P at (i, j).
116   */
117  @Override
118  public double getP(int row, int col) {
119    return m_P.get(row, col);
120  }
121
122  /**
123   * Sets the entire error covariance matrix P.
124   *
125   * @param newP The new value of P to use.
126   */
127  @Override
128  public void setP(Matrix<States, States> newP) {
129    m_P = newP;
130  }
131
132  /**
133   * Returns the state estimate x-hat.
134   *
135   * @return the state estimate x-hat.
136   */
137  @Override
138  public Matrix<States, N1> getXhat() {
139    return m_xHat;
140  }
141
142  /**
143   * Returns an element of the state estimate x-hat.
144   *
145   * @param row Row of x-hat.
146   * @return the value of the state estimate x-hat at that row.
147   */
148  @Override
149  public double getXhat(int row) {
150    return m_xHat.get(row, 0);
151  }
152
153  /**
154   * Set initial state estimate x-hat.
155   *
156   * @param xHat The state estimate x-hat.
157   */
158  @Override
159  public void setXhat(Matrix<States, N1> xHat) {
160    m_xHat = xHat;
161  }
162
163  /**
164   * Set an element of the initial state estimate x-hat.
165   *
166   * @param row Row of x-hat.
167   * @param value Value for element of x-hat.
168   */
169  @Override
170  public void setXhat(int row, double value) {
171    m_xHat.set(row, 0, value);
172  }
173
174  @Override
175  public final void reset() {
176    m_xHat = new Matrix<>(m_states, Nat.N1());
177    m_P = m_initP;
178  }
179
180  /**
181   * Project the model into the future with a new control input u.
182   *
183   * @param u New control input from controller.
184   * @param dtSeconds Timestep for prediction.
185   */
186  @Override
187  public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
188    // Find discrete A and Q
189    final var discPair = Discretization.discretizeAQ(m_plant.getA(), m_contQ, dtSeconds);
190    final var discA = discPair.getFirst();
191    final var discQ = discPair.getSecond();
192
193    m_xHat = m_plant.calculateX(m_xHat, u, dtSeconds);
194
195    // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q
196    m_P = discA.times(m_P).times(discA.transpose()).plus(discQ);
197
198    m_dtSeconds = dtSeconds;
199  }
200
201  /**
202   * Correct the state estimate x-hat using the measurements in y.
203   *
204   * @param u Same control input used in the predict step.
205   * @param y Measurement vector.
206   */
207  @Override
208  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
209    correct(u, y, m_contR);
210  }
211
212  /**
213   * Correct the state estimate x-hat using the measurements in y.
214   *
215   * <p>This is useful for when the measurement noise covariances vary.
216   *
217   * @param u Same control input used in the predict step.
218   * @param y Measurement vector.
219   * @param R Continuous measurement noise covariance matrix.
220   */
221  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y, Matrix<Outputs, Outputs> R) {
222    final var C = m_plant.getC();
223    final var D = m_plant.getD();
224
225    final var discR = Discretization.discretizeR(R, m_dtSeconds);
226
227    final var S = C.times(m_P).times(C.transpose()).plus(discR);
228
229    // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more
230    // efficiently.
231    //
232    // K = PCᵀS⁻¹
233    // KS = PCᵀ
234    // (KS)ᵀ = (PCᵀ)ᵀ
235    // SᵀKᵀ = CPᵀ
236    //
237    // The solution of Ax = b can be found via x = A.solve(b).
238    //
239    // Kᵀ = Sᵀ.solve(CPᵀ)
240    // K = (Sᵀ.solve(CPᵀ))ᵀ
241    final Matrix<States, Outputs> K = S.transpose().solve(C.times(m_P.transpose())).transpose();
242
243    // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − (Cx̂ₖ₊₁⁻ + Duₖ₊₁))
244    m_xHat = m_xHat.plus(K.times(y.minus(C.times(m_xHat).plus(D.times(u)))));
245
246    // Pₖ₊₁⁺ = (I−Kₖ₊₁C)Pₖ₊₁⁻(I−Kₖ₊₁C)ᵀ + Kₖ₊₁RKₖ₊₁ᵀ
247    // Use Joseph form for numerical stability
248    m_P =
249        Matrix.eye(m_states)
250            .minus(K.times(C))
251            .times(m_P)
252            .times(Matrix.eye(m_states).minus(K.times(C)).transpose())
253            .plus(K.times(discR).times(K.transpose()));
254  }
255}