001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.DARE;
008import edu.wpi.first.math.Matrix;
009import edu.wpi.first.math.Nat;
010import edu.wpi.first.math.Num;
011import edu.wpi.first.math.StateSpaceUtil;
012import edu.wpi.first.math.numbers.N1;
013import edu.wpi.first.math.system.Discretization;
014import edu.wpi.first.math.system.LinearSystem;
015
016/**
017 * A Kalman filter combines predictions from a model and measurements to give an estimate of the
018 * true system state. This is useful because many states cannot be measured directly as a result of
019 * sensor noise, or because the state is "hidden".
020 *
021 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
022 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
023 * of squares error in the state estimate. This K gain is used to correct the state estimate by some
024 * amount of the difference between the actual measurements and the measurements predicted by the
025 * model.
026 *
027 * <p>For more on the underlying math, read <a
028 * href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">https://file.tavsys.net/control/controls-engineering-in-frc.pdf</a>
029 * chapter 9 "Stochastic control theory".
030 *
031 * @param <States> Number of states.
032 * @param <Inputs> Number of inputs.
033 * @param <Outputs> Number of outputs.
034 */
035public class KalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
036    implements KalmanTypeFilter<States, Inputs, Outputs> {
037  private final Nat<States> m_states;
038
039  private final LinearSystem<States, Inputs, Outputs> m_plant;
040  private Matrix<States, N1> m_xHat;
041  private Matrix<States, States> m_P;
042  private final Matrix<States, States> m_contQ;
043  private final Matrix<Outputs, Outputs> m_contR;
044  private double m_dtSeconds;
045
046  private final Matrix<States, States> m_initP;
047
048  /**
049   * Constructs a Kalman filter with the given plant.
050   *
051   * <p>See <a
052   * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
053   * for how to select the standard deviations.
054   *
055   * @param states A Nat representing the states of the system.
056   * @param outputs A Nat representing the outputs of the system.
057   * @param plant The plant used for the prediction step.
058   * @param stateStdDevs Standard deviations of model states.
059   * @param measurementStdDevs Standard deviations of measurements.
060   * @param dtSeconds Nominal discretization timestep.
061   * @throws IllegalArgumentException If the system is undetectable.
062   */
063  public KalmanFilter(
064      Nat<States> states,
065      Nat<Outputs> outputs,
066      LinearSystem<States, Inputs, Outputs> plant,
067      Matrix<States, N1> stateStdDevs,
068      Matrix<Outputs, N1> measurementStdDevs,
069      double dtSeconds) {
070    this.m_states = states;
071
072    this.m_plant = plant;
073
074    m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
075    m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
076    m_dtSeconds = dtSeconds;
077
078    // Find discrete A and Q
079    var pair = Discretization.discretizeAQ(plant.getA(), m_contQ, dtSeconds);
080    var discA = pair.getFirst();
081    var discQ = pair.getSecond();
082
083    var discR = Discretization.discretizeR(m_contR, dtSeconds);
084
085    var C = plant.getC();
086
087    m_initP = new Matrix<>(DARE.dare(discA.transpose(), C.transpose(), discQ, discR));
088
089    reset();
090  }
091
092  /**
093   * Returns the error covariance matrix P.
094   *
095   * @return the error covariance matrix P.
096   */
097  @Override
098  public Matrix<States, States> getP() {
099    return m_P;
100  }
101
102  /**
103   * Returns an element of the error covariance matrix P.
104   *
105   * @param row Row of P.
106   * @param col Column of P.
107   * @return the value of the error covariance matrix P at (i, j).
108   */
109  @Override
110  public double getP(int row, int col) {
111    return m_P.get(row, col);
112  }
113
114  /**
115   * Sets the entire error covariance matrix P.
116   *
117   * @param newP The new value of P to use.
118   */
119  @Override
120  public void setP(Matrix<States, States> newP) {
121    m_P = newP;
122  }
123
124  /**
125   * Returns the state estimate x-hat.
126   *
127   * @return the state estimate x-hat.
128   */
129  @Override
130  public Matrix<States, N1> getXhat() {
131    return m_xHat;
132  }
133
134  /**
135   * Returns an element of the state estimate x-hat.
136   *
137   * @param row Row of x-hat.
138   * @return the value of the state estimate x-hat at that row.
139   */
140  @Override
141  public double getXhat(int row) {
142    return m_xHat.get(row, 0);
143  }
144
145  /**
146   * Set initial state estimate x-hat.
147   *
148   * @param xHat The state estimate x-hat.
149   */
150  @Override
151  public void setXhat(Matrix<States, N1> xHat) {
152    m_xHat = xHat;
153  }
154
155  /**
156   * Set an element of the initial state estimate x-hat.
157   *
158   * @param row Row of x-hat.
159   * @param value Value for element of x-hat.
160   */
161  @Override
162  public void setXhat(int row, double value) {
163    m_xHat.set(row, 0, value);
164  }
165
166  @Override
167  public final void reset() {
168    m_xHat = new Matrix<>(m_states, Nat.N1());
169    m_P = m_initP;
170  }
171
172  /**
173   * Project the model into the future with a new control input u.
174   *
175   * @param u New control input from controller.
176   * @param dtSeconds Timestep for prediction.
177   */
178  @Override
179  public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
180    // Find discrete A and Q
181    final var discPair = Discretization.discretizeAQ(m_plant.getA(), m_contQ, dtSeconds);
182    final var discA = discPair.getFirst();
183    final var discQ = discPair.getSecond();
184
185    m_xHat = m_plant.calculateX(m_xHat, u, dtSeconds);
186
187    // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q
188    m_P = discA.times(m_P).times(discA.transpose()).plus(discQ);
189
190    m_dtSeconds = dtSeconds;
191  }
192
193  /**
194   * Correct the state estimate x-hat using the measurements in y.
195   *
196   * @param u Same control input used in the predict step.
197   * @param y Measurement vector.
198   */
199  @Override
200  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
201    correct(u, y, m_contR);
202  }
203
204  /**
205   * Correct the state estimate x-hat using the measurements in y.
206   *
207   * <p>This is useful for when the measurement noise covariances vary.
208   *
209   * @param u Same control input used in the predict step.
210   * @param y Measurement vector.
211   * @param R Continuous measurement noise covariance matrix.
212   */
213  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y, Matrix<Outputs, Outputs> R) {
214    final var C = m_plant.getC();
215    final var D = m_plant.getD();
216
217    final var discR = Discretization.discretizeR(R, m_dtSeconds);
218
219    final var S = C.times(m_P).times(C.transpose()).plus(discR);
220
221    // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more
222    // efficiently.
223    //
224    // K = PCᵀS⁻¹
225    // KS = PCᵀ
226    // (KS)ᵀ = (PCᵀ)ᵀ
227    // SᵀKᵀ = CPᵀ
228    //
229    // The solution of Ax = b can be found via x = A.solve(b).
230    //
231    // Kᵀ = Sᵀ.solve(CPᵀ)
232    // K = (Sᵀ.solve(CPᵀ))ᵀ
233    final Matrix<States, Outputs> K = S.transpose().solve(C.times(m_P.transpose())).transpose();
234
235    // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − (Cx̂ₖ₊₁⁻ + Duₖ₊₁))
236    m_xHat = m_xHat.plus(K.times(y.minus(C.times(m_xHat).plus(D.times(u)))));
237
238    // Pₖ₊₁⁺ = (I−Kₖ₊₁C)Pₖ₊₁⁻(I−Kₖ₊₁C)ᵀ + Kₖ₊₁RKₖ₊₁ᵀ
239    // Use Joseph form for numerical stability
240    m_P =
241        Matrix.eye(m_states)
242            .minus(K.times(C))
243            .times(m_P)
244            .times(Matrix.eye(m_states).minus(K.times(C)).transpose())
245            .plus(K.times(discR).times(K.transpose()));
246  }
247}