001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.estimator; 006 007import edu.wpi.first.math.Matrix; 008import edu.wpi.first.math.Nat; 009import edu.wpi.first.math.Num; 010import edu.wpi.first.math.Pair; 011import edu.wpi.first.math.StateSpaceUtil; 012import edu.wpi.first.math.numbers.N1; 013import edu.wpi.first.math.system.Discretization; 014import edu.wpi.first.math.system.NumericalIntegration; 015import edu.wpi.first.math.system.NumericalJacobian; 016import java.util.function.BiFunction; 017import org.ejml.dense.row.decomposition.qr.QRDecompositionHouseholder_DDRM; 018import org.ejml.simple.SimpleMatrix; 019 020/** 021 * A Kalman filter combines predictions from a model and measurements to give an estimate of the 022 * true system state. This is useful because many states cannot be measured directly as a result of 023 * sensor noise, or because the state is "hidden". 024 * 025 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements 026 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum 027 * of squares error in the state estimate. This K gain is used to correct the state estimate by some 028 * amount of the difference between the actual measurements and the measurements predicted by the 029 * model. 030 * 031 * <p>An unscented Kalman filter uses nonlinear state and measurement models. It propagates the 032 * error covariance using sigma points chosen to approximate the true probability distribution. 033 * 034 * <p>For more on the underlying math, read <a 035 * href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">https://file.tavsys.net/control/controls-engineering-in-frc.pdf</a> 036 * chapter 9 "Stochastic control theory". 037 * 038 * <p>This class implements a square-root-form unscented Kalman filter (SR-UKF). For more 039 * information about the SR-UKF, see <a 040 * href="https://www.researchgate.net/publication/3908304">https://www.researchgate.net/publication/3908304</a>. 041 * 042 * @param <States> Number of states. 043 * @param <Inputs> Number of inputs. 044 * @param <Outputs> Number of outputs. 045 */ 046public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num> 047 implements KalmanTypeFilter<States, Inputs, Outputs> { 048 private final Nat<States> m_states; 049 private final Nat<Outputs> m_outputs; 050 051 private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f; 052 private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h; 053 054 private BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> m_meanFuncX; 055 private BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> m_meanFuncY; 056 private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_residualFuncX; 057 private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY; 058 private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX; 059 060 private Matrix<States, N1> m_xHat; 061 private Matrix<States, States> m_S; 062 private final Matrix<States, States> m_contQ; 063 private final Matrix<Outputs, Outputs> m_contR; 064 private Matrix<States, ?> m_sigmasF; 065 private double m_dtSeconds; 066 067 private final MerweScaledSigmaPoints<States> m_pts; 068 069 /** 070 * Constructs an Unscented Kalman Filter. 071 * 072 * <p>See <a 073 * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a> 074 * for how to select the standard deviations. 075 * 076 * @param states A Nat representing the number of states. 077 * @param outputs A Nat representing the number of outputs. 078 * @param f A vector-valued function of x and u that returns the derivative of the state vector. 079 * @param h A vector-valued function of x and u that returns the measurement vector. 080 * @param stateStdDevs Standard deviations of model states. 081 * @param measurementStdDevs Standard deviations of measurements. 082 * @param nominalDtSeconds Nominal discretization timestep. 083 */ 084 public UnscentedKalmanFilter( 085 Nat<States> states, 086 Nat<Outputs> outputs, 087 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 088 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h, 089 Matrix<States, N1> stateStdDevs, 090 Matrix<Outputs, N1> measurementStdDevs, 091 double nominalDtSeconds) { 092 this( 093 states, 094 outputs, 095 f, 096 h, 097 stateStdDevs, 098 measurementStdDevs, 099 (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)), 100 (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)), 101 Matrix::minus, 102 Matrix::minus, 103 Matrix::plus, 104 nominalDtSeconds); 105 } 106 107 /** 108 * Constructs an unscented Kalman filter with custom mean, residual, and addition functions. Using 109 * custom functions for arithmetic can be useful if you have angles in the state or measurements, 110 * because they allow you to correctly account for the modular nature of angle arithmetic. 111 * 112 * <p>See <a 113 * href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a> 114 * for how to select the standard deviations. 115 * 116 * @param states A Nat representing the number of states. 117 * @param outputs A Nat representing the number of outputs. 118 * @param f A vector-valued function of x and u that returns the derivative of the state vector. 119 * @param h A vector-valued function of x and u that returns the measurement vector. 120 * @param stateStdDevs Standard deviations of model states. 121 * @param measurementStdDevs Standard deviations of measurements. 122 * @param meanFuncX A function that computes the mean of 2 * States + 1 state vectors using a 123 * given set of weights. 124 * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using 125 * a given set of weights. 126 * @param residualFuncX A function that computes the residual of two state vectors (i.e. it 127 * subtracts them.) 128 * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it 129 * subtracts them.) 130 * @param addFuncX A function that adds two state vectors. 131 * @param nominalDtSeconds Nominal discretization timestep. 132 */ 133 public UnscentedKalmanFilter( 134 Nat<States> states, 135 Nat<Outputs> outputs, 136 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 137 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h, 138 Matrix<States, N1> stateStdDevs, 139 Matrix<Outputs, N1> measurementStdDevs, 140 BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> meanFuncX, 141 BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> meanFuncY, 142 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX, 143 BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY, 144 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX, 145 double nominalDtSeconds) { 146 this.m_states = states; 147 this.m_outputs = outputs; 148 149 m_f = f; 150 m_h = h; 151 152 m_meanFuncX = meanFuncX; 153 m_meanFuncY = meanFuncY; 154 m_residualFuncX = residualFuncX; 155 m_residualFuncY = residualFuncY; 156 m_addFuncX = addFuncX; 157 158 m_dtSeconds = nominalDtSeconds; 159 160 m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs); 161 m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs); 162 163 m_pts = new MerweScaledSigmaPoints<>(states); 164 165 reset(); 166 } 167 168 static <S extends Num, C extends Num> 169 Pair<Matrix<C, N1>, Matrix<C, C>> squareRootUnscentedTransform( 170 Nat<S> s, 171 Nat<C> dim, 172 Matrix<C, ?> sigmas, 173 Matrix<?, N1> Wm, 174 Matrix<?, N1> Wc, 175 BiFunction<Matrix<C, ?>, Matrix<?, N1>, Matrix<C, N1>> meanFunc, 176 BiFunction<Matrix<C, N1>, Matrix<C, N1>, Matrix<C, N1>> residualFunc, 177 Matrix<C, C> squareRootR) { 178 if (sigmas.getNumRows() != dim.getNum() || sigmas.getNumCols() != 2 * s.getNum() + 1) { 179 throw new IllegalArgumentException( 180 "Sigmas must be covDim by 2 * states + 1! Got " 181 + sigmas.getNumRows() 182 + " by " 183 + sigmas.getNumCols()); 184 } 185 186 if (Wm.getNumRows() != 2 * s.getNum() + 1 || Wm.getNumCols() != 1) { 187 throw new IllegalArgumentException( 188 "Wm must be 2 * states + 1 by 1! Got " + Wm.getNumRows() + " by " + Wm.getNumCols()); 189 } 190 191 if (Wc.getNumRows() != 2 * s.getNum() + 1 || Wc.getNumCols() != 1) { 192 throw new IllegalArgumentException( 193 "Wc must be 2 * states + 1 by 1! Got " + Wc.getNumRows() + " by " + Wc.getNumCols()); 194 } 195 196 // New mean is usually just the sum of the sigmas * weight: 197 // n 198 // dot = Σ W[k] Xᵢ[k] 199 // k=1 200 Matrix<C, N1> x = meanFunc.apply(sigmas, Wm); 201 202 Matrix<C, ?> Sbar = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + dim.getNum())); 203 for (int i = 0; i < 2 * s.getNum(); i++) { 204 Sbar.setColumn( 205 i, 206 residualFunc.apply(sigmas.extractColumnVector(1 + i), x).times(Math.sqrt(Wc.get(1, 0)))); 207 } 208 Sbar.assignBlock(0, 2 * s.getNum(), squareRootR); 209 210 QRDecompositionHouseholder_DDRM qr = new QRDecompositionHouseholder_DDRM(); 211 var qrStorage = Sbar.transpose().getStorage(); 212 213 if (!qr.decompose(qrStorage.getDDRM())) { 214 throw new RuntimeException("QR decomposition failed! Input matrix:\n" + qrStorage); 215 } 216 217 Matrix<C, C> newS = new Matrix<>(new SimpleMatrix(qr.getR(null, true))); 218 newS.rankUpdate(residualFunc.apply(sigmas.extractColumnVector(0), x), Wc.get(0, 0), false); 219 220 return new Pair<>(x, newS); 221 } 222 223 /** 224 * Returns the square-root error covariance matrix S. 225 * 226 * @return the square-root error covariance matrix S. 227 */ 228 public Matrix<States, States> getS() { 229 return m_S; 230 } 231 232 /** 233 * Returns an element of the square-root error covariance matrix S. 234 * 235 * @param row Row of S. 236 * @param col Column of S. 237 * @return the value of the square-root error covariance matrix S at (i, j). 238 */ 239 public double getS(int row, int col) { 240 return m_S.get(row, col); 241 } 242 243 /** 244 * Sets the entire square-root error covariance matrix S. 245 * 246 * @param newS The new value of S to use. 247 */ 248 public void setS(Matrix<States, States> newS) { 249 m_S = newS; 250 } 251 252 /** 253 * Returns the reconstructed error covariance matrix P. 254 * 255 * @return the error covariance matrix P. 256 */ 257 @Override 258 public Matrix<States, States> getP() { 259 return m_S.transpose().times(m_S); 260 } 261 262 /** 263 * Returns an element of the error covariance matrix P. 264 * 265 * @param row Row of P. 266 * @param col Column of P. 267 * @return the value of the error covariance matrix P at (i, j). 268 * @throws UnsupportedOperationException indexing into the reconstructed P matrix is not supported 269 */ 270 @Override 271 public double getP(int row, int col) { 272 throw new UnsupportedOperationException( 273 "indexing into the reconstructed P matrix is not supported"); 274 } 275 276 /** 277 * Sets the entire error covariance matrix P. 278 * 279 * @param newP The new value of P to use. 280 */ 281 @Override 282 public void setP(Matrix<States, States> newP) { 283 m_S = newP.lltDecompose(false); 284 } 285 286 /** 287 * Returns the state estimate x-hat. 288 * 289 * @return the state estimate x-hat. 290 */ 291 @Override 292 public Matrix<States, N1> getXhat() { 293 return m_xHat; 294 } 295 296 /** 297 * Returns an element of the state estimate x-hat. 298 * 299 * @param row Row of x-hat. 300 * @return the value of the state estimate x-hat at 'i'. 301 */ 302 @Override 303 public double getXhat(int row) { 304 return m_xHat.get(row, 0); 305 } 306 307 /** 308 * Set initial state estimate x-hat. 309 * 310 * @param xHat The state estimate x-hat. 311 */ 312 @Override 313 public void setXhat(Matrix<States, N1> xHat) { 314 m_xHat = xHat; 315 } 316 317 /** 318 * Set an element of the initial state estimate x-hat. 319 * 320 * @param row Row of x-hat. 321 * @param value Value for element of x-hat. 322 */ 323 @Override 324 public void setXhat(int row, double value) { 325 m_xHat.set(row, 0, value); 326 } 327 328 /** Resets the observer. */ 329 @Override 330 public final void reset() { 331 m_xHat = new Matrix<>(m_states, Nat.N1()); 332 m_S = new Matrix<>(m_states, m_states); 333 m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1)); 334 } 335 336 /** 337 * Project the model into the future with a new control input u. 338 * 339 * @param u New control input from controller. 340 * @param dtSeconds Timestep for prediction. 341 */ 342 @Override 343 public void predict(Matrix<Inputs, N1> u, double dtSeconds) { 344 // Discretize Q before projecting mean and covariance forward 345 Matrix<States, States> contA = 346 NumericalJacobian.numericalJacobianX(m_states, m_states, m_f, m_xHat, u); 347 var discQ = Discretization.discretizeAQ(contA, m_contQ, dtSeconds).getSecond(); 348 var squareRootDiscQ = discQ.lltDecompose(true); 349 350 var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S); 351 352 for (int i = 0; i < m_pts.getNumSigmas(); ++i) { 353 Matrix<States, N1> x = sigmas.extractColumnVector(i); 354 355 m_sigmasF.setColumn(i, NumericalIntegration.rk4(m_f, x, u, dtSeconds)); 356 } 357 358 var ret = 359 squareRootUnscentedTransform( 360 m_states, 361 m_states, 362 m_sigmasF, 363 m_pts.getWm(), 364 m_pts.getWc(), 365 m_meanFuncX, 366 m_residualFuncX, 367 squareRootDiscQ); 368 369 m_xHat = ret.getFirst(); 370 m_S = ret.getSecond(); 371 m_dtSeconds = dtSeconds; 372 } 373 374 /** 375 * Correct the state estimate x-hat using the measurements in y. 376 * 377 * @param u Same control input used in the predict step. 378 * @param y Measurement vector. 379 */ 380 @Override 381 public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) { 382 correct( 383 m_outputs, u, y, m_h, m_contR, m_meanFuncY, m_residualFuncY, m_residualFuncX, m_addFuncX); 384 } 385 386 /** 387 * Correct the state estimate x-hat using the measurements in y. 388 * 389 * <p>This is useful for when the measurement noise covariances vary. 390 * 391 * @param u Same control input used in the predict step. 392 * @param y Measurement vector. 393 * @param R Continuous measurement noise covariance matrix. 394 */ 395 public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y, Matrix<Outputs, Outputs> R) { 396 correct(m_outputs, u, y, m_h, R, m_meanFuncY, m_residualFuncY, m_residualFuncX, m_addFuncX); 397 } 398 399 /** 400 * Correct the state estimate x-hat using the measurements in y. 401 * 402 * <p>This is useful for when the measurements available during a timestep's Correct() call vary. 403 * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version 404 * of this function). 405 * 406 * @param <R> Number of measurements in y. 407 * @param rows Number of rows in y. 408 * @param u Same control input used in the predict step. 409 * @param y Measurement vector. 410 * @param h A vector-valued function of x and u that returns the measurement vector. 411 * @param R Continuous measurement noise covariance matrix. 412 */ 413 public <R extends Num> void correct( 414 Nat<R> rows, 415 Matrix<Inputs, N1> u, 416 Matrix<R, N1> y, 417 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h, 418 Matrix<R, R> R) { 419 BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY = 420 (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)); 421 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX = 422 Matrix::minus; 423 BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY = Matrix::minus; 424 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX = Matrix::plus; 425 correct(rows, u, y, h, R, meanFuncY, residualFuncY, residualFuncX, addFuncX); 426 } 427 428 /** 429 * Correct the state estimate x-hat using the measurements in y. 430 * 431 * <p>This is useful for when the measurements available during a timestep's Correct() call vary. 432 * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version 433 * of this function). 434 * 435 * @param <R> Number of measurements in y. 436 * @param rows Number of rows in y. 437 * @param u Same control input used in the predict step. 438 * @param y Measurement vector. 439 * @param h A vector-valued function of x and u that returns the measurement vector. 440 * @param R Continuous measurement noise covariance matrix. 441 * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using 442 * a given set of weights. 443 * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it 444 * subtracts them.) 445 * @param residualFuncX A function that computes the residual of two state vectors (i.e. it 446 * subtracts them.) 447 * @param addFuncX A function that adds two state vectors. 448 */ 449 public <R extends Num> void correct( 450 Nat<R> rows, 451 Matrix<Inputs, N1> u, 452 Matrix<R, N1> y, 453 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h, 454 Matrix<R, R> R, 455 BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY, 456 BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY, 457 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX, 458 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) { 459 final var discR = Discretization.discretizeR(R, m_dtSeconds); 460 final var squareRootDiscR = discR.lltDecompose(true); 461 462 // Transform sigma points into measurement space 463 Matrix<R, ?> sigmasH = new Matrix<>(new SimpleMatrix(rows.getNum(), 2 * m_states.getNum() + 1)); 464 var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S); 465 for (int i = 0; i < m_pts.getNumSigmas(); i++) { 466 Matrix<R, N1> hRet = h.apply(sigmas.extractColumnVector(i), u); 467 sigmasH.setColumn(i, hRet); 468 } 469 470 // Mean and covariance of prediction passed through unscented transform 471 var transRet = 472 squareRootUnscentedTransform( 473 m_states, 474 rows, 475 sigmasH, 476 m_pts.getWm(), 477 m_pts.getWc(), 478 meanFuncY, 479 residualFuncY, 480 squareRootDiscR); 481 var yHat = transRet.getFirst(); 482 var Sy = transRet.getSecond(); 483 484 // Compute cross covariance of the state and the measurements 485 Matrix<States, R> Pxy = new Matrix<>(m_states, rows); 486 for (int i = 0; i < m_pts.getNumSigmas(); i++) { 487 // Pxy += (sigmas_f[:, i] - x̂)(sigmas_h[:, i] - ŷ)ᵀ W_c[i] 488 var dx = residualFuncX.apply(m_sigmasF.extractColumnVector(i), m_xHat); 489 var dy = residualFuncY.apply(sigmasH.extractColumnVector(i), yHat).transpose(); 490 491 Pxy = Pxy.plus(dx.times(dy).times(m_pts.getWc(i))); 492 } 493 494 // K = (P_{xy} / S_yᵀ) / S_y 495 // K = (S_y \ P_{xy}ᵀ)ᵀ / S_y 496 // K = (S_yᵀ \ (S_y \ P_{xy}ᵀ))ᵀ 497 Matrix<States, R> K = 498 Sy.transpose() 499 .solveFullPivHouseholderQr(Sy.solveFullPivHouseholderQr(Pxy.transpose())) 500 .transpose(); 501 502 // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − ŷ) 503 m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, yHat))); 504 505 Matrix<States, R> U = K.times(Sy); 506 for (int i = 0; i < rows.getNum(); i++) { 507 m_S.rankUpdate(U.extractColumnVector(i), -1, false); 508 } 509 } 510}