001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.estimator; 006 007import edu.wpi.first.math.DARE; 008import edu.wpi.first.math.MathSharedStore; 009import edu.wpi.first.math.Matrix; 010import edu.wpi.first.math.Nat; 011import edu.wpi.first.math.Num; 012import edu.wpi.first.math.StateSpaceUtil; 013import edu.wpi.first.math.numbers.N1; 014import edu.wpi.first.math.system.Discretization; 015import edu.wpi.first.math.system.LinearSystem; 016 017/** 018 * A Kalman filter combines predictions from a model and measurements to give an estimate of the 019 * true system state. This is useful because many states cannot be measured directly as a result of 020 * sensor noise, or because the state is "hidden". 021 * 022 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements 023 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum 024 * of squares error in the state estimate. This K gain is used to correct the state estimate by some 025 * amount of the difference between the actual measurements and the measurements predicted by the 026 * model. 027 * 028 * <p>For more on the underlying math, read <a 029 * href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">https://file.tavsys.net/control/controls-engineering-in-frc.pdf</a> 030 * chapter 9 "Stochastic control theory". 031 * 032 * @param <States> Number of states. 033 * @param <Inputs> Number of inputs. 034 * @param <Outputs> Number of outputs. 035 */ 036public class KalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num> 037 implements KalmanTypeFilter<States, Inputs, Outputs> { 038 private final Nat<States> m_states; 039 040 private final LinearSystem<States, Inputs, Outputs> m_plant; 041 private Matrix<States, N1> m_xHat; 042 private Matrix<States, States> m_P; 043 private final Matrix<States, States> m_contQ; 044 private final Matrix<Outputs, Outputs> m_contR; 045 private double m_dtSeconds; 046 047 private final Matrix<States, States> m_initP; 048 049 /** 050 * Constructs a Kalman filter with the given plant. 051 * 052 * <p>See <a 053 * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a> 054 * for how to select the standard deviations. 055 * 056 * @param states A Nat representing the states of the system. 057 * @param outputs A Nat representing the outputs of the system. 058 * @param plant The plant used for the prediction step. 059 * @param stateStdDevs Standard deviations of model states. 060 * @param measurementStdDevs Standard deviations of measurements. 061 * @param dtSeconds Nominal discretization timestep. 062 * @throws IllegalArgumentException If the system is unobservable. 063 */ 064 public KalmanFilter( 065 Nat<States> states, 066 Nat<Outputs> outputs, 067 LinearSystem<States, Inputs, Outputs> plant, 068 Matrix<States, N1> stateStdDevs, 069 Matrix<Outputs, N1> measurementStdDevs, 070 double dtSeconds) { 071 this.m_states = states; 072 073 this.m_plant = plant; 074 075 m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs); 076 m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs); 077 m_dtSeconds = dtSeconds; 078 079 // Find discrete A and Q 080 var pair = Discretization.discretizeAQ(plant.getA(), m_contQ, dtSeconds); 081 var discA = pair.getFirst(); 082 var discQ = pair.getSecond(); 083 084 var discR = Discretization.discretizeR(m_contR, dtSeconds); 085 086 var C = plant.getC(); 087 088 if (!StateSpaceUtil.isDetectable(discA, C)) { 089 var msg = 090 "The system passed to the Kalman filter is unobservable!\n\nA =\n" 091 + discA.getStorage().toString() 092 + "\nC =\n" 093 + C.getStorage().toString() 094 + '\n'; 095 MathSharedStore.reportError(msg, Thread.currentThread().getStackTrace()); 096 throw new IllegalArgumentException(msg); 097 } 098 099 m_initP = new Matrix<>(DARE.dare(discA.transpose(), C.transpose(), discQ, discR)); 100 101 reset(); 102 } 103 104 /** 105 * Returns the error covariance matrix P. 106 * 107 * @return the error covariance matrix P. 108 */ 109 @Override 110 public Matrix<States, States> getP() { 111 return m_P; 112 } 113 114 /** 115 * Returns an element of the error covariance matrix P. 116 * 117 * @param row Row of P. 118 * @param col Column of P. 119 * @return the value of the error covariance matrix P at (i, j). 120 */ 121 @Override 122 public double getP(int row, int col) { 123 return m_P.get(row, col); 124 } 125 126 /** 127 * Sets the entire error covariance matrix P. 128 * 129 * @param newP The new value of P to use. 130 */ 131 @Override 132 public void setP(Matrix<States, States> newP) { 133 m_P = newP; 134 } 135 136 /** 137 * Returns the state estimate x-hat. 138 * 139 * @return the state estimate x-hat. 140 */ 141 @Override 142 public Matrix<States, N1> getXhat() { 143 return m_xHat; 144 } 145 146 /** 147 * Returns an element of the state estimate x-hat. 148 * 149 * @param row Row of x-hat. 150 * @return the value of the state estimate x-hat at that row. 151 */ 152 @Override 153 public double getXhat(int row) { 154 return m_xHat.get(row, 0); 155 } 156 157 /** 158 * Set initial state estimate x-hat. 159 * 160 * @param xHat The state estimate x-hat. 161 */ 162 @Override 163 public void setXhat(Matrix<States, N1> xHat) { 164 m_xHat = xHat; 165 } 166 167 /** 168 * Set an element of the initial state estimate x-hat. 169 * 170 * @param row Row of x-hat. 171 * @param value Value for element of x-hat. 172 */ 173 @Override 174 public void setXhat(int row, double value) { 175 m_xHat.set(row, 0, value); 176 } 177 178 @Override 179 public final void reset() { 180 m_xHat = new Matrix<>(m_states, Nat.N1()); 181 m_P = m_initP; 182 } 183 184 /** 185 * Project the model into the future with a new control input u. 186 * 187 * @param u New control input from controller. 188 * @param dtSeconds Timestep for prediction. 189 */ 190 @Override 191 public void predict(Matrix<Inputs, N1> u, double dtSeconds) { 192 // Find discrete A and Q 193 final var discPair = Discretization.discretizeAQ(m_plant.getA(), m_contQ, dtSeconds); 194 final var discA = discPair.getFirst(); 195 final var discQ = discPair.getSecond(); 196 197 m_xHat = m_plant.calculateX(m_xHat, u, dtSeconds); 198 199 // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q 200 m_P = discA.times(m_P).times(discA.transpose()).plus(discQ); 201 202 m_dtSeconds = dtSeconds; 203 } 204 205 /** 206 * Correct the state estimate x-hat using the measurements in y. 207 * 208 * @param u Same control input used in the predict step. 209 * @param y Measurement vector. 210 */ 211 @Override 212 public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) { 213 correct(u, y, m_contR); 214 } 215 216 /** 217 * Correct the state estimate x-hat using the measurements in y. 218 * 219 * <p>This is useful for when the measurement noise covariances vary. 220 * 221 * @param u Same control input used in the predict step. 222 * @param y Measurement vector. 223 * @param R Continuous measurement noise covariance matrix. 224 */ 225 public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y, Matrix<Outputs, Outputs> R) { 226 final var C = m_plant.getC(); 227 final var D = m_plant.getD(); 228 229 final var discR = Discretization.discretizeR(R, m_dtSeconds); 230 231 final var S = C.times(m_P).times(C.transpose()).plus(discR); 232 233 // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more 234 // efficiently. 235 // 236 // K = PCᵀS⁻¹ 237 // KS = PCᵀ 238 // (KS)ᵀ = (PCᵀ)ᵀ 239 // SᵀKᵀ = CPᵀ 240 // 241 // The solution of Ax = b can be found via x = A.solve(b). 242 // 243 // Kᵀ = Sᵀ.solve(CPᵀ) 244 // K = (Sᵀ.solve(CPᵀ))ᵀ 245 final Matrix<States, Outputs> K = S.transpose().solve(C.times(m_P.transpose())).transpose(); 246 247 // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − (Cx̂ₖ₊₁⁻ + Duₖ₊₁)) 248 m_xHat = m_xHat.plus(K.times(y.minus(C.times(m_xHat).plus(D.times(u))))); 249 250 // Pₖ₊₁⁺ = (I−Kₖ₊₁C)Pₖ₊₁⁻(I−Kₖ₊₁C)ᵀ + Kₖ₊₁RKₖ₊₁ᵀ 251 // Use Joseph form for numerical stability 252 m_P = 253 Matrix.eye(m_states) 254 .minus(K.times(C)) 255 .times(m_P) 256 .times(Matrix.eye(m_states).minus(K.times(C)).transpose()) 257 .plus(K.times(discR).times(K.transpose())); 258 } 259}