001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.DARE;
008import edu.wpi.first.math.MathSharedStore;
009import edu.wpi.first.math.Matrix;
010import edu.wpi.first.math.Nat;
011import edu.wpi.first.math.Num;
012import edu.wpi.first.math.StateSpaceUtil;
013import edu.wpi.first.math.numbers.N1;
014import edu.wpi.first.math.system.Discretization;
015import edu.wpi.first.math.system.LinearSystem;
016
017/**
018 * A Kalman filter combines predictions from a model and measurements to give an estimate of the
019 * true system state. This is useful because many states cannot be measured directly as a result of
020 * sensor noise, or because the state is "hidden".
021 *
022 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
023 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
024 * of squares error in the state estimate. This K gain is used to correct the state estimate by some
025 * amount of the difference between the actual measurements and the measurements predicted by the
026 * model.
027 *
028 * <p>For more on the underlying math, read <a
029 * href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">https://file.tavsys.net/control/controls-engineering-in-frc.pdf</a>
030 * chapter 9 "Stochastic control theory".
031 *
032 * @param <States> Number of states.
033 * @param <Inputs> Number of inputs.
034 * @param <Outputs> Number of outputs.
035 */
036public class KalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
037    implements KalmanTypeFilter<States, Inputs, Outputs> {
038  private final Nat<States> m_states;
039
040  private final LinearSystem<States, Inputs, Outputs> m_plant;
041  private Matrix<States, N1> m_xHat;
042  private Matrix<States, States> m_P;
043  private final Matrix<States, States> m_contQ;
044  private final Matrix<Outputs, Outputs> m_contR;
045  private double m_dtSeconds;
046
047  private final Matrix<States, States> m_initP;
048
049  /**
050   * Constructs a Kalman filter with the given plant.
051   *
052   * <p>See <a
053   * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
054   * for how to select the standard deviations.
055   *
056   * @param states A Nat representing the states of the system.
057   * @param outputs A Nat representing the outputs of the system.
058   * @param plant The plant used for the prediction step.
059   * @param stateStdDevs Standard deviations of model states.
060   * @param measurementStdDevs Standard deviations of measurements.
061   * @param dtSeconds Nominal discretization timestep.
062   * @throws IllegalArgumentException If the system is unobservable.
063   */
064  public KalmanFilter(
065      Nat<States> states,
066      Nat<Outputs> outputs,
067      LinearSystem<States, Inputs, Outputs> plant,
068      Matrix<States, N1> stateStdDevs,
069      Matrix<Outputs, N1> measurementStdDevs,
070      double dtSeconds) {
071    this.m_states = states;
072
073    this.m_plant = plant;
074
075    m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
076    m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
077    m_dtSeconds = dtSeconds;
078
079    // Find discrete A and Q
080    var pair = Discretization.discretizeAQ(plant.getA(), m_contQ, dtSeconds);
081    var discA = pair.getFirst();
082    var discQ = pair.getSecond();
083
084    var discR = Discretization.discretizeR(m_contR, dtSeconds);
085
086    var C = plant.getC();
087
088    if (!StateSpaceUtil.isDetectable(discA, C)) {
089      var msg =
090          "The system passed to the Kalman filter is unobservable!\n\nA =\n"
091              + discA.getStorage().toString()
092              + "\nC =\n"
093              + C.getStorage().toString()
094              + '\n';
095      MathSharedStore.reportError(msg, Thread.currentThread().getStackTrace());
096      throw new IllegalArgumentException(msg);
097    }
098
099    m_initP = new Matrix<>(DARE.dare(discA.transpose(), C.transpose(), discQ, discR));
100
101    reset();
102  }
103
104  /**
105   * Returns the error covariance matrix P.
106   *
107   * @return the error covariance matrix P.
108   */
109  @Override
110  public Matrix<States, States> getP() {
111    return m_P;
112  }
113
114  /**
115   * Returns an element of the error covariance matrix P.
116   *
117   * @param row Row of P.
118   * @param col Column of P.
119   * @return the value of the error covariance matrix P at (i, j).
120   */
121  @Override
122  public double getP(int row, int col) {
123    return m_P.get(row, col);
124  }
125
126  /**
127   * Sets the entire error covariance matrix P.
128   *
129   * @param newP The new value of P to use.
130   */
131  @Override
132  public void setP(Matrix<States, States> newP) {
133    m_P = newP;
134  }
135
136  /**
137   * Returns the state estimate x-hat.
138   *
139   * @return the state estimate x-hat.
140   */
141  @Override
142  public Matrix<States, N1> getXhat() {
143    return m_xHat;
144  }
145
146  /**
147   * Returns an element of the state estimate x-hat.
148   *
149   * @param row Row of x-hat.
150   * @return the value of the state estimate x-hat at that row.
151   */
152  @Override
153  public double getXhat(int row) {
154    return m_xHat.get(row, 0);
155  }
156
157  /**
158   * Set initial state estimate x-hat.
159   *
160   * @param xHat The state estimate x-hat.
161   */
162  @Override
163  public void setXhat(Matrix<States, N1> xHat) {
164    m_xHat = xHat;
165  }
166
167  /**
168   * Set an element of the initial state estimate x-hat.
169   *
170   * @param row Row of x-hat.
171   * @param value Value for element of x-hat.
172   */
173  @Override
174  public void setXhat(int row, double value) {
175    m_xHat.set(row, 0, value);
176  }
177
178  @Override
179  public final void reset() {
180    m_xHat = new Matrix<>(m_states, Nat.N1());
181    m_P = m_initP;
182  }
183
184  /**
185   * Project the model into the future with a new control input u.
186   *
187   * @param u New control input from controller.
188   * @param dtSeconds Timestep for prediction.
189   */
190  @Override
191  public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
192    // Find discrete A and Q
193    final var discPair = Discretization.discretizeAQ(m_plant.getA(), m_contQ, dtSeconds);
194    final var discA = discPair.getFirst();
195    final var discQ = discPair.getSecond();
196
197    m_xHat = m_plant.calculateX(m_xHat, u, dtSeconds);
198
199    // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q
200    m_P = discA.times(m_P).times(discA.transpose()).plus(discQ);
201
202    m_dtSeconds = dtSeconds;
203  }
204
205  /**
206   * Correct the state estimate x-hat using the measurements in y.
207   *
208   * @param u Same control input used in the predict step.
209   * @param y Measurement vector.
210   */
211  @Override
212  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
213    correct(u, y, m_contR);
214  }
215
216  /**
217   * Correct the state estimate x-hat using the measurements in y.
218   *
219   * <p>This is useful for when the measurement noise covariances vary.
220   *
221   * @param u Same control input used in the predict step.
222   * @param y Measurement vector.
223   * @param R Continuous measurement noise covariance matrix.
224   */
225  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y, Matrix<Outputs, Outputs> R) {
226    final var C = m_plant.getC();
227    final var D = m_plant.getD();
228
229    final var discR = Discretization.discretizeR(R, m_dtSeconds);
230
231    final var S = C.times(m_P).times(C.transpose()).plus(discR);
232
233    // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more
234    // efficiently.
235    //
236    // K = PCᵀS⁻¹
237    // KS = PCᵀ
238    // (KS)ᵀ = (PCᵀ)ᵀ
239    // SᵀKᵀ = CPᵀ
240    //
241    // The solution of Ax = b can be found via x = A.solve(b).
242    //
243    // Kᵀ = Sᵀ.solve(CPᵀ)
244    // K = (Sᵀ.solve(CPᵀ))ᵀ
245    final Matrix<States, Outputs> K = S.transpose().solve(C.times(m_P.transpose())).transpose();
246
247    // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − (Cx̂ₖ₊₁⁻ + Duₖ₊₁))
248    m_xHat = m_xHat.plus(K.times(y.minus(C.times(m_xHat).plus(D.times(u)))));
249
250    // Pₖ₊₁⁺ = (I−Kₖ₊₁C)Pₖ₊₁⁻(I−Kₖ₊₁C)ᵀ + Kₖ₊₁RKₖ₊₁ᵀ
251    // Use Joseph form for numerical stability
252    m_P =
253        Matrix.eye(m_states)
254            .minus(K.times(C))
255            .times(m_P)
256            .times(Matrix.eye(m_states).minus(K.times(C)).transpose())
257            .plus(K.times(discR).times(K.transpose()));
258  }
259}