001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.estimator; 006 007import edu.wpi.first.math.Matrix; 008import edu.wpi.first.math.Nat; 009import edu.wpi.first.math.Num; 010import edu.wpi.first.math.Pair; 011import edu.wpi.first.math.StateSpaceUtil; 012import edu.wpi.first.math.numbers.N1; 013import edu.wpi.first.math.system.Discretization; 014import edu.wpi.first.math.system.NumericalIntegration; 015import edu.wpi.first.math.system.NumericalJacobian; 016import java.util.function.BiFunction; 017import org.ejml.dense.row.decomposition.qr.QRDecompositionHouseholder_DDRM; 018import org.ejml.simple.SimpleMatrix; 019 020/** 021 * A Kalman filter combines predictions from a model and measurements to give an estimate of the 022 * true system state. This is useful because many states cannot be measured directly as a result of 023 * sensor noise, or because the state is "hidden". 024 * 025 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements 026 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum 027 * of squares error in the state estimate. This K gain is used to correct the state estimate by some 028 * amount of the difference between the actual measurements and the measurements predicted by the 029 * model. 030 * 031 * <p>An unscented Kalman filter uses nonlinear state and measurement models. It propagates the 032 * error covariance using sigma points chosen to approximate the true probability distribution. 033 * 034 * <p>For more on the underlying math, read <a 035 * href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">https://file.tavsys.net/control/controls-engineering-in-frc.pdf</a> 036 * chapter 9 "Stochastic control theory". 037 * 038 * <p>This class implements a square-root-form unscented Kalman filter (SR-UKF). For more 039 * information about the SR-UKF, see <a 040 * href="https://www.researchgate.net/publication/3908304">https://www.researchgate.net/publication/3908304</a>. 041 */ 042public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num> 043 implements KalmanTypeFilter<States, Inputs, Outputs> { 044 private final Nat<States> m_states; 045 private final Nat<Outputs> m_outputs; 046 047 private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f; 048 private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h; 049 050 private BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> m_meanFuncX; 051 private BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> m_meanFuncY; 052 private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_residualFuncX; 053 private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY; 054 private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX; 055 056 private Matrix<States, N1> m_xHat; 057 private Matrix<States, States> m_S; 058 private final Matrix<States, States> m_contQ; 059 private final Matrix<Outputs, Outputs> m_contR; 060 private Matrix<States, ?> m_sigmasF; 061 private double m_dtSeconds; 062 063 private final MerweScaledSigmaPoints<States> m_pts; 064 065 /** 066 * Constructs an Unscented Kalman Filter. 067 * 068 * <p>See <a 069 * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a> 070 * for how to select the standard deviations. 071 * 072 * @param states A Nat representing the number of states. 073 * @param outputs A Nat representing the number of outputs. 074 * @param f A vector-valued function of x and u that returns the derivative of the state vector. 075 * @param h A vector-valued function of x and u that returns the measurement vector. 076 * @param stateStdDevs Standard deviations of model states. 077 * @param measurementStdDevs Standard deviations of measurements. 078 * @param nominalDtSeconds Nominal discretization timestep. 079 */ 080 public UnscentedKalmanFilter( 081 Nat<States> states, 082 Nat<Outputs> outputs, 083 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 084 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h, 085 Matrix<States, N1> stateStdDevs, 086 Matrix<Outputs, N1> measurementStdDevs, 087 double nominalDtSeconds) { 088 this( 089 states, 090 outputs, 091 f, 092 h, 093 stateStdDevs, 094 measurementStdDevs, 095 (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)), 096 (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)), 097 Matrix::minus, 098 Matrix::minus, 099 Matrix::plus, 100 nominalDtSeconds); 101 } 102 103 /** 104 * Constructs an unscented Kalman filter with custom mean, residual, and addition functions. Using 105 * custom functions for arithmetic can be useful if you have angles in the state or measurements, 106 * because they allow you to correctly account for the modular nature of angle arithmetic. 107 * 108 * <p>See <a 109 * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a> 110 * for how to select the standard deviations. 111 * 112 * @param states A Nat representing the number of states. 113 * @param outputs A Nat representing the number of outputs. 114 * @param f A vector-valued function of x and u that returns the derivative of the state vector. 115 * @param h A vector-valued function of x and u that returns the measurement vector. 116 * @param stateStdDevs Standard deviations of model states. 117 * @param measurementStdDevs Standard deviations of measurements. 118 * @param meanFuncX A function that computes the mean of 2 * States + 1 state vectors using a 119 * given set of weights. 120 * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using 121 * a given set of weights. 122 * @param residualFuncX A function that computes the residual of two state vectors (i.e. it 123 * subtracts them.) 124 * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it 125 * subtracts them.) 126 * @param addFuncX A function that adds two state vectors. 127 * @param nominalDtSeconds Nominal discretization timestep. 128 */ 129 public UnscentedKalmanFilter( 130 Nat<States> states, 131 Nat<Outputs> outputs, 132 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 133 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h, 134 Matrix<States, N1> stateStdDevs, 135 Matrix<Outputs, N1> measurementStdDevs, 136 BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> meanFuncX, 137 BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> meanFuncY, 138 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX, 139 BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY, 140 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX, 141 double nominalDtSeconds) { 142 this.m_states = states; 143 this.m_outputs = outputs; 144 145 m_f = f; 146 m_h = h; 147 148 m_meanFuncX = meanFuncX; 149 m_meanFuncY = meanFuncY; 150 m_residualFuncX = residualFuncX; 151 m_residualFuncY = residualFuncY; 152 m_addFuncX = addFuncX; 153 154 m_dtSeconds = nominalDtSeconds; 155 156 m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs); 157 m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs); 158 159 m_pts = new MerweScaledSigmaPoints<>(states); 160 161 reset(); 162 } 163 164 static <S extends Num, C extends Num> 165 Pair<Matrix<C, N1>, Matrix<C, C>> squareRootUnscentedTransform( 166 Nat<S> s, 167 Nat<C> dim, 168 Matrix<C, ?> sigmas, 169 Matrix<?, N1> Wm, 170 Matrix<?, N1> Wc, 171 BiFunction<Matrix<C, ?>, Matrix<?, N1>, Matrix<C, N1>> meanFunc, 172 BiFunction<Matrix<C, N1>, Matrix<C, N1>, Matrix<C, N1>> residualFunc, 173 Matrix<C, C> squareRootR) { 174 if (sigmas.getNumRows() != dim.getNum() || sigmas.getNumCols() != 2 * s.getNum() + 1) { 175 throw new IllegalArgumentException( 176 "Sigmas must be covDim by 2 * states + 1! Got " 177 + sigmas.getNumRows() 178 + " by " 179 + sigmas.getNumCols()); 180 } 181 182 if (Wm.getNumRows() != 2 * s.getNum() + 1 || Wm.getNumCols() != 1) { 183 throw new IllegalArgumentException( 184 "Wm must be 2 * states + 1 by 1! Got " + Wm.getNumRows() + " by " + Wm.getNumCols()); 185 } 186 187 if (Wc.getNumRows() != 2 * s.getNum() + 1 || Wc.getNumCols() != 1) { 188 throw new IllegalArgumentException( 189 "Wc must be 2 * states + 1 by 1! Got " + Wc.getNumRows() + " by " + Wc.getNumCols()); 190 } 191 192 // New mean is usually just the sum of the sigmas * weight: 193 // n 194 // dot = Σ W[k] Xᵢ[k] 195 // k=1 196 Matrix<C, N1> x = meanFunc.apply(sigmas, Wm); 197 198 Matrix<C, ?> Sbar = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + dim.getNum())); 199 for (int i = 0; i < 2 * s.getNum(); i++) { 200 Sbar.setColumn( 201 i, 202 residualFunc.apply(sigmas.extractColumnVector(1 + i), x).times(Math.sqrt(Wc.get(1, 0)))); 203 } 204 Sbar.assignBlock(0, 2 * s.getNum(), squareRootR); 205 206 QRDecompositionHouseholder_DDRM qr = new QRDecompositionHouseholder_DDRM(); 207 var qrStorage = Sbar.transpose().getStorage(); 208 209 if (!qr.decompose(qrStorage.getDDRM())) { 210 throw new RuntimeException("QR decomposition failed! Input matrix:\n" + qrStorage); 211 } 212 213 Matrix<C, C> newS = new Matrix<>(new SimpleMatrix(qr.getR(null, true))); 214 newS.rankUpdate(residualFunc.apply(sigmas.extractColumnVector(0), x), Wc.get(0, 0), false); 215 216 return new Pair<>(x, newS); 217 } 218 219 /** 220 * Returns the square-root error covariance matrix S. 221 * 222 * @return the square-root error covariance matrix S. 223 */ 224 public Matrix<States, States> getS() { 225 return m_S; 226 } 227 228 /** 229 * Returns an element of the square-root error covariance matrix S. 230 * 231 * @param row Row of S. 232 * @param col Column of S. 233 * @return the value of the square-root error covariance matrix S at (i, j). 234 */ 235 public double getS(int row, int col) { 236 return m_S.get(row, col); 237 } 238 239 /** 240 * Sets the entire square-root error covariance matrix S. 241 * 242 * @param newS The new value of S to use. 243 */ 244 public void setS(Matrix<States, States> newS) { 245 m_S = newS; 246 } 247 248 /** 249 * Returns the reconstructed error covariance matrix P. 250 * 251 * @return the error covariance matrix P. 252 */ 253 @Override 254 public Matrix<States, States> getP() { 255 return m_S.transpose().times(m_S); 256 } 257 258 /** 259 * Returns an element of the error covariance matrix P. 260 * 261 * @param row Row of P. 262 * @param col Column of P. 263 * @return the value of the error covariance matrix P at (i, j). 264 * @throws UnsupportedOperationException indexing into the reconstructed P matrix is not supported 265 */ 266 @Override 267 public double getP(int row, int col) { 268 throw new UnsupportedOperationException( 269 "indexing into the reconstructed P matrix is not supported"); 270 } 271 272 /** 273 * Sets the entire error covariance matrix P. 274 * 275 * @param newP The new value of P to use. 276 */ 277 @Override 278 public void setP(Matrix<States, States> newP) { 279 m_S = newP.lltDecompose(false); 280 } 281 282 /** 283 * Returns the state estimate x-hat. 284 * 285 * @return the state estimate x-hat. 286 */ 287 @Override 288 public Matrix<States, N1> getXhat() { 289 return m_xHat; 290 } 291 292 /** 293 * Returns an element of the state estimate x-hat. 294 * 295 * @param row Row of x-hat. 296 * @return the value of the state estimate x-hat at 'i'. 297 */ 298 @Override 299 public double getXhat(int row) { 300 return m_xHat.get(row, 0); 301 } 302 303 /** 304 * Set initial state estimate x-hat. 305 * 306 * @param xHat The state estimate x-hat. 307 */ 308 @Override 309 public void setXhat(Matrix<States, N1> xHat) { 310 m_xHat = xHat; 311 } 312 313 /** 314 * Set an element of the initial state estimate x-hat. 315 * 316 * @param row Row of x-hat. 317 * @param value Value for element of x-hat. 318 */ 319 @Override 320 public void setXhat(int row, double value) { 321 m_xHat.set(row, 0, value); 322 } 323 324 /** Resets the observer. */ 325 @Override 326 public final void reset() { 327 m_xHat = new Matrix<>(m_states, Nat.N1()); 328 m_S = new Matrix<>(m_states, m_states); 329 m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1)); 330 } 331 332 /** 333 * Project the model into the future with a new control input u. 334 * 335 * @param u New control input from controller. 336 * @param dtSeconds Timestep for prediction. 337 */ 338 @Override 339 public void predict(Matrix<Inputs, N1> u, double dtSeconds) { 340 // Discretize Q before projecting mean and covariance forward 341 Matrix<States, States> contA = 342 NumericalJacobian.numericalJacobianX(m_states, m_states, m_f, m_xHat, u); 343 var discQ = Discretization.discretizeAQ(contA, m_contQ, dtSeconds).getSecond(); 344 var squareRootDiscQ = discQ.lltDecompose(true); 345 346 var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S); 347 348 for (int i = 0; i < m_pts.getNumSigmas(); ++i) { 349 Matrix<States, N1> x = sigmas.extractColumnVector(i); 350 351 m_sigmasF.setColumn(i, NumericalIntegration.rk4(m_f, x, u, dtSeconds)); 352 } 353 354 var ret = 355 squareRootUnscentedTransform( 356 m_states, 357 m_states, 358 m_sigmasF, 359 m_pts.getWm(), 360 m_pts.getWc(), 361 m_meanFuncX, 362 m_residualFuncX, 363 squareRootDiscQ); 364 365 m_xHat = ret.getFirst(); 366 m_S = ret.getSecond(); 367 m_dtSeconds = dtSeconds; 368 } 369 370 /** 371 * Correct the state estimate x-hat using the measurements in y. 372 * 373 * @param u Same control input used in the predict step. 374 * @param y Measurement vector. 375 */ 376 @Override 377 public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) { 378 correct( 379 m_outputs, u, y, m_h, m_contR, m_meanFuncY, m_residualFuncY, m_residualFuncX, m_addFuncX); 380 } 381 382 /** 383 * Correct the state estimate x-hat using the measurements in y. 384 * 385 * <p>This is useful for when the measurement noise covariances vary. 386 * 387 * @param u Same control input used in the predict step. 388 * @param y Measurement vector. 389 * @param R Continuous measurement noise covariance matrix. 390 */ 391 public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y, Matrix<Outputs, Outputs> R) { 392 correct(m_outputs, u, y, m_h, R, m_meanFuncY, m_residualFuncY, m_residualFuncX, m_addFuncX); 393 } 394 395 /** 396 * Correct the state estimate x-hat using the measurements in y. 397 * 398 * <p>This is useful for when the measurements available during a timestep's Correct() call vary. 399 * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version 400 * of this function). 401 * 402 * @param <R> Number of measurements in y. 403 * @param rows Number of rows in y. 404 * @param u Same control input used in the predict step. 405 * @param y Measurement vector. 406 * @param h A vector-valued function of x and u that returns the measurement vector. 407 * @param R Continuous measurement noise covariance matrix. 408 */ 409 public <R extends Num> void correct( 410 Nat<R> rows, 411 Matrix<Inputs, N1> u, 412 Matrix<R, N1> y, 413 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h, 414 Matrix<R, R> R) { 415 BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY = 416 (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)); 417 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX = 418 Matrix::minus; 419 BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY = Matrix::minus; 420 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX = Matrix::plus; 421 correct(rows, u, y, h, R, meanFuncY, residualFuncY, residualFuncX, addFuncX); 422 } 423 424 /** 425 * Correct the state estimate x-hat using the measurements in y. 426 * 427 * <p>This is useful for when the measurements available during a timestep's Correct() call vary. 428 * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version 429 * of this function). 430 * 431 * @param <R> Number of measurements in y. 432 * @param rows Number of rows in y. 433 * @param u Same control input used in the predict step. 434 * @param y Measurement vector. 435 * @param h A vector-valued function of x and u that returns the measurement vector. 436 * @param R Continuous measurement noise covariance matrix. 437 * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using 438 * a given set of weights. 439 * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it 440 * subtracts them.) 441 * @param residualFuncX A function that computes the residual of two state vectors (i.e. it 442 * subtracts them.) 443 * @param addFuncX A function that adds two state vectors. 444 */ 445 public <R extends Num> void correct( 446 Nat<R> rows, 447 Matrix<Inputs, N1> u, 448 Matrix<R, N1> y, 449 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h, 450 Matrix<R, R> R, 451 BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY, 452 BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY, 453 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX, 454 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) { 455 final var discR = Discretization.discretizeR(R, m_dtSeconds); 456 final var squareRootDiscR = discR.lltDecompose(true); 457 458 // Transform sigma points into measurement space 459 Matrix<R, ?> sigmasH = new Matrix<>(new SimpleMatrix(rows.getNum(), 2 * m_states.getNum() + 1)); 460 var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S); 461 for (int i = 0; i < m_pts.getNumSigmas(); i++) { 462 Matrix<R, N1> hRet = h.apply(sigmas.extractColumnVector(i), u); 463 sigmasH.setColumn(i, hRet); 464 } 465 466 // Mean and covariance of prediction passed through unscented transform 467 var transRet = 468 squareRootUnscentedTransform( 469 m_states, 470 rows, 471 sigmasH, 472 m_pts.getWm(), 473 m_pts.getWc(), 474 meanFuncY, 475 residualFuncY, 476 squareRootDiscR); 477 var yHat = transRet.getFirst(); 478 var Sy = transRet.getSecond(); 479 480 // Compute cross covariance of the state and the measurements 481 Matrix<States, R> Pxy = new Matrix<>(m_states, rows); 482 for (int i = 0; i < m_pts.getNumSigmas(); i++) { 483 // Pxy += (sigmas_f[:, i] - x̂)(sigmas_h[:, i] - ŷ)ᵀ W_c[i] 484 var dx = residualFuncX.apply(m_sigmasF.extractColumnVector(i), m_xHat); 485 var dy = residualFuncY.apply(sigmasH.extractColumnVector(i), yHat).transpose(); 486 487 Pxy = Pxy.plus(dx.times(dy).times(m_pts.getWc(i))); 488 } 489 490 // K = (P_{xy} / S_yᵀ) / S_y 491 // K = (S_y \ P_{xy}ᵀ)ᵀ / S_y 492 // K = (S_yᵀ \ (S_y \ P_{xy}ᵀ))ᵀ 493 Matrix<States, R> K = 494 Sy.transpose() 495 .solveFullPivHouseholderQr(Sy.solveFullPivHouseholderQr(Pxy.transpose())) 496 .transpose(); 497 498 // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − ŷ) 499 m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, yHat))); 500 501 Matrix<States, R> U = K.times(Sy); 502 for (int i = 0; i < rows.getNum(); i++) { 503 m_S.rankUpdate(U.extractColumnVector(i), -1, false); 504 } 505 } 506}