001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.estimator; 006 007import edu.wpi.first.math.DARE; 008import edu.wpi.first.math.Matrix; 009import edu.wpi.first.math.Nat; 010import edu.wpi.first.math.Num; 011import edu.wpi.first.math.StateSpaceUtil; 012import edu.wpi.first.math.numbers.N1; 013import edu.wpi.first.math.system.Discretization; 014import edu.wpi.first.math.system.NumericalIntegration; 015import edu.wpi.first.math.system.NumericalJacobian; 016import java.util.function.BiFunction; 017 018/** 019 * A Kalman filter combines predictions from a model and measurements to give an estimate of the 020 * true system state. This is useful because many states cannot be measured directly as a result of 021 * sensor noise, or because the state is "hidden". 022 * 023 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements 024 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum 025 * of squares error in the state estimate. This K gain is used to correct the state estimate by some 026 * amount of the difference between the actual measurements and the measurements predicted by the 027 * model. 028 * 029 * <p>An extended Kalman filter supports nonlinear state and measurement models. It propagates the 030 * error covariance by linearizing the models around the state estimate, then applying the linear 031 * Kalman filter equations. 032 * 033 * <p>For more on the underlying math, read <a 034 * href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">https://file.tavsys.net/control/controls-engineering-in-frc.pdf</a> 035 * chapter 9 "Stochastic control theory". 036 * 037 * @param <States> Number of states. 038 * @param <Inputs> Number of inputs. 039 * @param <Outputs> Number of outputs. 040 */ 041public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num> 042 implements KalmanTypeFilter<States, Inputs, Outputs> { 043 private final Nat<States> m_states; 044 private final Nat<Outputs> m_outputs; 045 046 private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f; 047 048 private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h; 049 050 private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY; 051 private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX; 052 053 private final Matrix<States, States> m_contQ; 054 private final Matrix<States, States> m_initP; 055 private final Matrix<Outputs, Outputs> m_contR; 056 057 private Matrix<States, N1> m_xHat; 058 059 private Matrix<States, States> m_P; 060 061 private double m_dtSeconds; 062 063 /** 064 * Constructs an extended Kalman filter. 065 * 066 * <p>See <a 067 * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a> 068 * for how to select the standard deviations. 069 * 070 * @param states a Nat representing the number of states. 071 * @param inputs a Nat representing the number of inputs. 072 * @param outputs a Nat representing the number of outputs. 073 * @param f A vector-valued function of x and u that returns the derivative of the state vector. 074 * @param h A vector-valued function of x and u that returns the measurement vector. 075 * @param stateStdDevs Standard deviations of model states. 076 * @param measurementStdDevs Standard deviations of measurements. 077 * @param dtSeconds Nominal discretization timestep. 078 */ 079 public ExtendedKalmanFilter( 080 Nat<States> states, 081 Nat<Inputs> inputs, 082 Nat<Outputs> outputs, 083 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 084 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h, 085 Matrix<States, N1> stateStdDevs, 086 Matrix<Outputs, N1> measurementStdDevs, 087 double dtSeconds) { 088 this( 089 states, 090 inputs, 091 outputs, 092 f, 093 h, 094 stateStdDevs, 095 measurementStdDevs, 096 Matrix::minus, 097 Matrix::plus, 098 dtSeconds); 099 } 100 101 /** 102 * Constructs an extended Kalman filter. 103 * 104 * <p>See <a 105 * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a> 106 * for how to select the standard deviations. 107 * 108 * @param states a Nat representing the number of states. 109 * @param inputs a Nat representing the number of inputs. 110 * @param outputs a Nat representing the number of outputs. 111 * @param f A vector-valued function of x and u that returns the derivative of the state vector. 112 * @param h A vector-valued function of x and u that returns the measurement vector. 113 * @param stateStdDevs Standard deviations of model states. 114 * @param measurementStdDevs Standard deviations of measurements. 115 * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it 116 * subtracts them.) 117 * @param addFuncX A function that adds two state vectors. 118 * @param dtSeconds Nominal discretization timestep. 119 */ 120 public ExtendedKalmanFilter( 121 Nat<States> states, 122 Nat<Inputs> inputs, 123 Nat<Outputs> outputs, 124 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 125 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h, 126 Matrix<States, N1> stateStdDevs, 127 Matrix<Outputs, N1> measurementStdDevs, 128 BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY, 129 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX, 130 double dtSeconds) { 131 m_states = states; 132 m_outputs = outputs; 133 134 m_f = f; 135 m_h = h; 136 137 m_residualFuncY = residualFuncY; 138 m_addFuncX = addFuncX; 139 140 m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs); 141 m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs); 142 m_dtSeconds = dtSeconds; 143 144 reset(); 145 146 final var contA = 147 NumericalJacobian.numericalJacobianX( 148 states, states, f, m_xHat, new Matrix<>(inputs, Nat.N1())); 149 final var C = 150 NumericalJacobian.numericalJacobianX( 151 outputs, states, h, m_xHat, new Matrix<>(inputs, Nat.N1())); 152 153 final var discPair = Discretization.discretizeAQ(contA, m_contQ, dtSeconds); 154 final var discA = discPair.getFirst(); 155 final var discQ = discPair.getSecond(); 156 157 final var discR = Discretization.discretizeR(m_contR, dtSeconds); 158 159 if (StateSpaceUtil.isDetectable(discA, C) && outputs.getNum() <= states.getNum()) { 160 m_initP = DARE.dare(discA.transpose(), C.transpose(), discQ, discR); 161 } else { 162 m_initP = new Matrix<>(states, states); 163 } 164 165 m_P = m_initP; 166 } 167 168 /** 169 * Returns the error covariance matrix P. 170 * 171 * @return the error covariance matrix P. 172 */ 173 @Override 174 public Matrix<States, States> getP() { 175 return m_P; 176 } 177 178 /** 179 * Returns an element of the error covariance matrix P. 180 * 181 * @param row Row of P. 182 * @param col Column of P. 183 * @return the value of the error covariance matrix P at (i, j). 184 */ 185 @Override 186 public double getP(int row, int col) { 187 return m_P.get(row, col); 188 } 189 190 /** 191 * Sets the entire error covariance matrix P. 192 * 193 * @param newP The new value of P to use. 194 */ 195 @Override 196 public void setP(Matrix<States, States> newP) { 197 m_P = newP; 198 } 199 200 /** 201 * Returns the state estimate x-hat. 202 * 203 * @return the state estimate x-hat. 204 */ 205 @Override 206 public Matrix<States, N1> getXhat() { 207 return m_xHat; 208 } 209 210 /** 211 * Returns an element of the state estimate x-hat. 212 * 213 * @param row Row of x-hat. 214 * @return the value of the state estimate x-hat at that row. 215 */ 216 @Override 217 public double getXhat(int row) { 218 return m_xHat.get(row, 0); 219 } 220 221 /** 222 * Set initial state estimate x-hat. 223 * 224 * @param xHat The state estimate x-hat. 225 */ 226 @Override 227 public void setXhat(Matrix<States, N1> xHat) { 228 m_xHat = xHat; 229 } 230 231 /** 232 * Set an element of the initial state estimate x-hat. 233 * 234 * @param row Row of x-hat. 235 * @param value Value for element of x-hat. 236 */ 237 @Override 238 public void setXhat(int row, double value) { 239 m_xHat.set(row, 0, value); 240 } 241 242 @Override 243 public final void reset() { 244 m_xHat = new Matrix<>(m_states, Nat.N1()); 245 m_P = m_initP; 246 } 247 248 /** 249 * Project the model into the future with a new control input u. 250 * 251 * @param u New control input from controller. 252 * @param dtSeconds Timestep for prediction. 253 */ 254 @Override 255 public void predict(Matrix<Inputs, N1> u, double dtSeconds) { 256 predict(u, m_f, dtSeconds); 257 } 258 259 /** 260 * Project the model into the future with a new control input u. 261 * 262 * @param u New control input from controller. 263 * @param f The function used to linearize the model. 264 * @param dtSeconds Timestep for prediction. 265 */ 266 public void predict( 267 Matrix<Inputs, N1> u, 268 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 269 double dtSeconds) { 270 // Find continuous A 271 final var contA = NumericalJacobian.numericalJacobianX(m_states, m_states, f, m_xHat, u); 272 273 // Find discrete A and Q 274 final var discPair = Discretization.discretizeAQ(contA, m_contQ, dtSeconds); 275 final var discA = discPair.getFirst(); 276 final var discQ = discPair.getSecond(); 277 278 m_xHat = NumericalIntegration.rk4(f, m_xHat, u, dtSeconds); 279 280 // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q 281 m_P = discA.times(m_P).times(discA.transpose()).plus(discQ); 282 283 m_dtSeconds = dtSeconds; 284 } 285 286 /** 287 * Correct the state estimate x-hat using the measurements in y. 288 * 289 * @param u Same control input used in the predict step. 290 * @param y Measurement vector. 291 */ 292 @Override 293 public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) { 294 correct(m_outputs, u, y, m_h, m_contR, m_residualFuncY, m_addFuncX); 295 } 296 297 /** 298 * Correct the state estimate x-hat using the measurements in y. 299 * 300 * <p>This is useful for when the measurement noise covariances vary. 301 * 302 * @param u Same control input used in the predict step. 303 * @param y Measurement vector. 304 * @param R Continuous measurement noise covariance matrix. 305 */ 306 public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y, Matrix<Outputs, Outputs> R) { 307 correct(m_outputs, u, y, m_h, R, m_residualFuncY, m_addFuncX); 308 } 309 310 /** 311 * Correct the state estimate x-hat using the measurements in y. 312 * 313 * <p>This is useful for when the measurements available during a timestep's Correct() call vary. 314 * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version 315 * of this function). 316 * 317 * @param <Rows> Number of rows in the result of f(x, u). 318 * @param rows Number of rows in the result of f(x, u). 319 * @param u Same control input used in the predict step. 320 * @param y Measurement vector. 321 * @param h A vector-valued function of x and u that returns the measurement vector. 322 * @param R Continuous measurement noise covariance matrix. 323 */ 324 public <Rows extends Num> void correct( 325 Nat<Rows> rows, 326 Matrix<Inputs, N1> u, 327 Matrix<Rows, N1> y, 328 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h, 329 Matrix<Rows, Rows> R) { 330 correct(rows, u, y, h, R, Matrix::minus, Matrix::plus); 331 } 332 333 /** 334 * Correct the state estimate x-hat using the measurements in y. 335 * 336 * <p>This is useful for when the measurements available during a timestep's Correct() call vary. 337 * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version 338 * of this function). 339 * 340 * @param <Rows> Number of rows in the result of f(x, u). 341 * @param rows Number of rows in the result of f(x, u). 342 * @param u Same control input used in the predict step. 343 * @param y Measurement vector. 344 * @param h A vector-valued function of x and u that returns the measurement vector. 345 * @param R Continuous measurement noise covariance matrix. 346 * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it 347 * subtracts them.) 348 * @param addFuncX A function that adds two state vectors. 349 */ 350 public <Rows extends Num> void correct( 351 Nat<Rows> rows, 352 Matrix<Inputs, N1> u, 353 Matrix<Rows, N1> y, 354 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h, 355 Matrix<Rows, Rows> R, 356 BiFunction<Matrix<Rows, N1>, Matrix<Rows, N1>, Matrix<Rows, N1>> residualFuncY, 357 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) { 358 final var C = NumericalJacobian.numericalJacobianX(rows, m_states, h, m_xHat, u); 359 final var discR = Discretization.discretizeR(R, m_dtSeconds); 360 361 final var S = C.times(m_P).times(C.transpose()).plus(discR); 362 363 // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more 364 // efficiently. 365 // 366 // K = PCᵀS⁻¹ 367 // KS = PCᵀ 368 // (KS)ᵀ = (PCᵀ)ᵀ 369 // SᵀKᵀ = CPᵀ 370 // 371 // The solution of Ax = b can be found via x = A.solve(b). 372 // 373 // Kᵀ = Sᵀ.solve(CPᵀ) 374 // K = (Sᵀ.solve(CPᵀ))ᵀ 375 final Matrix<States, Rows> K = S.transpose().solve(C.times(m_P.transpose())).transpose(); 376 377 // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − h(x̂ₖ₊₁⁻, uₖ₊₁)) 378 m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, h.apply(m_xHat, u)))); 379 380 // Pₖ₊₁⁺ = (I−Kₖ₊₁C)Pₖ₊₁⁻(I−Kₖ₊₁C)ᵀ + Kₖ₊₁RKₖ₊₁ᵀ 381 // Use Joseph form for numerical stability 382 m_P = 383 Matrix.eye(m_states) 384 .minus(K.times(C)) 385 .times(m_P) 386 .times(Matrix.eye(m_states).minus(K.times(C)).transpose()) 387 .plus(K.times(discR).times(K.transpose())); 388 } 389}