001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.estimator; 006 007import edu.wpi.first.math.DARE; 008import edu.wpi.first.math.Matrix; 009import edu.wpi.first.math.Nat; 010import edu.wpi.first.math.Num; 011import edu.wpi.first.math.StateSpaceUtil; 012import edu.wpi.first.math.numbers.N1; 013import edu.wpi.first.math.system.Discretization; 014import edu.wpi.first.math.system.NumericalIntegration; 015import edu.wpi.first.math.system.NumericalJacobian; 016import java.util.function.BiFunction; 017 018/** 019 * A Kalman filter combines predictions from a model and measurements to give an estimate of the 020 * true system state. This is useful because many states cannot be measured directly as a result of 021 * sensor noise, or because the state is "hidden". 022 * 023 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements 024 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum 025 * of squares error in the state estimate. This K gain is used to correct the state estimate by some 026 * amount of the difference between the actual measurements and the measurements predicted by the 027 * model. 028 * 029 * <p>An extended Kalman filter supports nonlinear state and measurement models. It propagates the 030 * error covariance by linearizing the models around the state estimate, then applying the linear 031 * Kalman filter equations. 032 * 033 * <p>For more on the underlying math, read <a 034 * href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">https://file.tavsys.net/control/controls-engineering-in-frc.pdf</a> 035 * chapter 9 "Stochastic control theory". 036 */ 037public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num> 038 implements KalmanTypeFilter<States, Inputs, Outputs> { 039 private final Nat<States> m_states; 040 private final Nat<Outputs> m_outputs; 041 042 private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f; 043 044 private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h; 045 046 private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY; 047 private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX; 048 049 private final Matrix<States, States> m_contQ; 050 private final Matrix<States, States> m_initP; 051 private final Matrix<Outputs, Outputs> m_contR; 052 053 private Matrix<States, N1> m_xHat; 054 055 private Matrix<States, States> m_P; 056 057 private double m_dtSeconds; 058 059 /** 060 * Constructs an extended Kalman filter. 061 * 062 * <p>See <a 063 * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a> 064 * for how to select the standard deviations. 065 * 066 * @param states a Nat representing the number of states. 067 * @param inputs a Nat representing the number of inputs. 068 * @param outputs a Nat representing the number of outputs. 069 * @param f A vector-valued function of x and u that returns the derivative of the state vector. 070 * @param h A vector-valued function of x and u that returns the measurement vector. 071 * @param stateStdDevs Standard deviations of model states. 072 * @param measurementStdDevs Standard deviations of measurements. 073 * @param dtSeconds Nominal discretization timestep. 074 */ 075 public ExtendedKalmanFilter( 076 Nat<States> states, 077 Nat<Inputs> inputs, 078 Nat<Outputs> outputs, 079 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 080 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h, 081 Matrix<States, N1> stateStdDevs, 082 Matrix<Outputs, N1> measurementStdDevs, 083 double dtSeconds) { 084 this( 085 states, 086 inputs, 087 outputs, 088 f, 089 h, 090 stateStdDevs, 091 measurementStdDevs, 092 Matrix::minus, 093 Matrix::plus, 094 dtSeconds); 095 } 096 097 /** 098 * Constructs an extended Kalman filter. 099 * 100 * <p>See <a 101 * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a> 102 * for how to select the standard deviations. 103 * 104 * @param states a Nat representing the number of states. 105 * @param inputs a Nat representing the number of inputs. 106 * @param outputs a Nat representing the number of outputs. 107 * @param f A vector-valued function of x and u that returns the derivative of the state vector. 108 * @param h A vector-valued function of x and u that returns the measurement vector. 109 * @param stateStdDevs Standard deviations of model states. 110 * @param measurementStdDevs Standard deviations of measurements. 111 * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it 112 * subtracts them.) 113 * @param addFuncX A function that adds two state vectors. 114 * @param dtSeconds Nominal discretization timestep. 115 */ 116 public ExtendedKalmanFilter( 117 Nat<States> states, 118 Nat<Inputs> inputs, 119 Nat<Outputs> outputs, 120 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 121 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h, 122 Matrix<States, N1> stateStdDevs, 123 Matrix<Outputs, N1> measurementStdDevs, 124 BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY, 125 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX, 126 double dtSeconds) { 127 m_states = states; 128 m_outputs = outputs; 129 130 m_f = f; 131 m_h = h; 132 133 m_residualFuncY = residualFuncY; 134 m_addFuncX = addFuncX; 135 136 m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs); 137 m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs); 138 m_dtSeconds = dtSeconds; 139 140 reset(); 141 142 final var contA = 143 NumericalJacobian.numericalJacobianX( 144 states, states, f, m_xHat, new Matrix<>(inputs, Nat.N1())); 145 final var C = 146 NumericalJacobian.numericalJacobianX( 147 outputs, states, h, m_xHat, new Matrix<>(inputs, Nat.N1())); 148 149 final var discPair = Discretization.discretizeAQ(contA, m_contQ, dtSeconds); 150 final var discA = discPair.getFirst(); 151 final var discQ = discPair.getSecond(); 152 153 final var discR = Discretization.discretizeR(m_contR, dtSeconds); 154 155 if (StateSpaceUtil.isDetectable(discA, C) && outputs.getNum() <= states.getNum()) { 156 m_initP = DARE.dare(discA.transpose(), C.transpose(), discQ, discR); 157 } else { 158 m_initP = new Matrix<>(states, states); 159 } 160 161 m_P = m_initP; 162 } 163 164 /** 165 * Returns the error covariance matrix P. 166 * 167 * @return the error covariance matrix P. 168 */ 169 @Override 170 public Matrix<States, States> getP() { 171 return m_P; 172 } 173 174 /** 175 * Returns an element of the error covariance matrix P. 176 * 177 * @param row Row of P. 178 * @param col Column of P. 179 * @return the value of the error covariance matrix P at (i, j). 180 */ 181 @Override 182 public double getP(int row, int col) { 183 return m_P.get(row, col); 184 } 185 186 /** 187 * Sets the entire error covariance matrix P. 188 * 189 * @param newP The new value of P to use. 190 */ 191 @Override 192 public void setP(Matrix<States, States> newP) { 193 m_P = newP; 194 } 195 196 /** 197 * Returns the state estimate x-hat. 198 * 199 * @return the state estimate x-hat. 200 */ 201 @Override 202 public Matrix<States, N1> getXhat() { 203 return m_xHat; 204 } 205 206 /** 207 * Returns an element of the state estimate x-hat. 208 * 209 * @param row Row of x-hat. 210 * @return the value of the state estimate x-hat at that row. 211 */ 212 @Override 213 public double getXhat(int row) { 214 return m_xHat.get(row, 0); 215 } 216 217 /** 218 * Set initial state estimate x-hat. 219 * 220 * @param xHat The state estimate x-hat. 221 */ 222 @Override 223 public void setXhat(Matrix<States, N1> xHat) { 224 m_xHat = xHat; 225 } 226 227 /** 228 * Set an element of the initial state estimate x-hat. 229 * 230 * @param row Row of x-hat. 231 * @param value Value for element of x-hat. 232 */ 233 @Override 234 public void setXhat(int row, double value) { 235 m_xHat.set(row, 0, value); 236 } 237 238 @Override 239 public final void reset() { 240 m_xHat = new Matrix<>(m_states, Nat.N1()); 241 m_P = m_initP; 242 } 243 244 /** 245 * Project the model into the future with a new control input u. 246 * 247 * @param u New control input from controller. 248 * @param dtSeconds Timestep for prediction. 249 */ 250 @Override 251 public void predict(Matrix<Inputs, N1> u, double dtSeconds) { 252 predict(u, m_f, dtSeconds); 253 } 254 255 /** 256 * Project the model into the future with a new control input u. 257 * 258 * @param u New control input from controller. 259 * @param f The function used to linearize the model. 260 * @param dtSeconds Timestep for prediction. 261 */ 262 public void predict( 263 Matrix<Inputs, N1> u, 264 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 265 double dtSeconds) { 266 // Find continuous A 267 final var contA = NumericalJacobian.numericalJacobianX(m_states, m_states, f, m_xHat, u); 268 269 // Find discrete A and Q 270 final var discPair = Discretization.discretizeAQ(contA, m_contQ, dtSeconds); 271 final var discA = discPair.getFirst(); 272 final var discQ = discPair.getSecond(); 273 274 m_xHat = NumericalIntegration.rk4(f, m_xHat, u, dtSeconds); 275 276 // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q 277 m_P = discA.times(m_P).times(discA.transpose()).plus(discQ); 278 279 m_dtSeconds = dtSeconds; 280 } 281 282 /** 283 * Correct the state estimate x-hat using the measurements in y. 284 * 285 * @param u Same control input used in the predict step. 286 * @param y Measurement vector. 287 */ 288 @Override 289 public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) { 290 correct(m_outputs, u, y, m_h, m_contR, m_residualFuncY, m_addFuncX); 291 } 292 293 /** 294 * Correct the state estimate x-hat using the measurements in y. 295 * 296 * <p>This is useful for when the measurement noise covariances vary. 297 * 298 * @param u Same control input used in the predict step. 299 * @param y Measurement vector. 300 * @param R Continuous measurement noise covariance matrix. 301 */ 302 public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y, Matrix<Outputs, Outputs> R) { 303 correct(m_outputs, u, y, m_h, R, m_residualFuncY, m_addFuncX); 304 } 305 306 /** 307 * Correct the state estimate x-hat using the measurements in y. 308 * 309 * <p>This is useful for when the measurements available during a timestep's Correct() call vary. 310 * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version 311 * of this function). 312 * 313 * @param <Rows> Number of rows in the result of f(x, u). 314 * @param rows Number of rows in the result of f(x, u). 315 * @param u Same control input used in the predict step. 316 * @param y Measurement vector. 317 * @param h A vector-valued function of x and u that returns the measurement vector. 318 * @param R Continuous measurement noise covariance matrix. 319 */ 320 public <Rows extends Num> void correct( 321 Nat<Rows> rows, 322 Matrix<Inputs, N1> u, 323 Matrix<Rows, N1> y, 324 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h, 325 Matrix<Rows, Rows> R) { 326 correct(rows, u, y, h, R, Matrix::minus, Matrix::plus); 327 } 328 329 /** 330 * Correct the state estimate x-hat using the measurements in y. 331 * 332 * <p>This is useful for when the measurements available during a timestep's Correct() call vary. 333 * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version 334 * of this function). 335 * 336 * @param <Rows> Number of rows in the result of f(x, u). 337 * @param rows Number of rows in the result of f(x, u). 338 * @param u Same control input used in the predict step. 339 * @param y Measurement vector. 340 * @param h A vector-valued function of x and u that returns the measurement vector. 341 * @param R Continuous measurement noise covariance matrix. 342 * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it 343 * subtracts them.) 344 * @param addFuncX A function that adds two state vectors. 345 */ 346 public <Rows extends Num> void correct( 347 Nat<Rows> rows, 348 Matrix<Inputs, N1> u, 349 Matrix<Rows, N1> y, 350 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h, 351 Matrix<Rows, Rows> R, 352 BiFunction<Matrix<Rows, N1>, Matrix<Rows, N1>, Matrix<Rows, N1>> residualFuncY, 353 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) { 354 final var C = NumericalJacobian.numericalJacobianX(rows, m_states, h, m_xHat, u); 355 final var discR = Discretization.discretizeR(R, m_dtSeconds); 356 357 final var S = C.times(m_P).times(C.transpose()).plus(discR); 358 359 // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more 360 // efficiently. 361 // 362 // K = PCᵀS⁻¹ 363 // KS = PCᵀ 364 // (KS)ᵀ = (PCᵀ)ᵀ 365 // SᵀKᵀ = CPᵀ 366 // 367 // The solution of Ax = b can be found via x = A.solve(b). 368 // 369 // Kᵀ = Sᵀ.solve(CPᵀ) 370 // K = (Sᵀ.solve(CPᵀ))ᵀ 371 final Matrix<States, Rows> K = S.transpose().solve(C.times(m_P.transpose())).transpose(); 372 373 // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − h(x̂ₖ₊₁⁻, uₖ₊₁)) 374 m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, h.apply(m_xHat, u)))); 375 376 // Pₖ₊₁⁺ = (I−Kₖ₊₁C)Pₖ₊₁⁻(I−Kₖ₊₁C)ᵀ + Kₖ₊₁RKₖ₊₁ᵀ 377 // Use Joseph form for numerical stability 378 m_P = 379 Matrix.eye(m_states) 380 .minus(K.times(C)) 381 .times(m_P) 382 .times(Matrix.eye(m_states).minus(K.times(C)).transpose()) 383 .plus(K.times(discR).times(K.transpose())); 384 } 385}