WPILibC++ 2025.2.1
Loading...
Searching...
No Matches
UnscentedKalmanFilter.h
Go to the documentation of this file.
1// Copyright (c) FIRST and other WPILib contributors.
2// Open Source Software; you can modify and/or share it under the terms of
3// the WPILib BSD license file in the root directory of this project.
4
5#pragma once
6
7#include <functional>
8#include <utility>
9
10#include <Eigen/Cholesky>
11#include <wpi/SymbolExports.h>
12#include <wpi/array.h>
13
14#include "frc/EigenCore.h"
15#include "frc/StateSpaceUtil.h"
21#include "units/time.h"
22
23namespace frc {
24
25/**
26 * A Kalman filter combines predictions from a model and measurements to give an
27 * estimate of the true system state. This is useful because many states cannot
28 * be measured directly as a result of sensor noise, or because the state is
29 * "hidden".
30 *
31 * Kalman filters use a K gain matrix to determine whether to trust the model or
32 * measurements more. Kalman filter theory uses statistics to compute an optimal
33 * K gain which minimizes the sum of squares error in the state estimate. This K
34 * gain is used to correct the state estimate by some amount of the difference
35 * between the actual measurements and the measurements predicted by the model.
36 *
37 * An unscented Kalman filter uses nonlinear state and measurement models. It
38 * propagates the error covariance using sigma points chosen to approximate the
39 * true probability distribution.
40 *
41 * For more on the underlying math, read
42 * https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9
43 * "Stochastic control theory".
44 *
45 * <p> This class implements a square-root-form unscented Kalman filter
46 * (SR-UKF). For more information about the SR-UKF, see
47 * https://www.researchgate.net/publication/3908304.
48 *
49 * @tparam States Number of states.
50 * @tparam Inputs Number of inputs.
51 * @tparam Outputs Number of outputs.
52 */
53template <int States, int Inputs, int Outputs>
55 public:
59
62
64
65 /**
66 * Constructs an unscented Kalman filter.
67 *
68 * See
69 * https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices
70 * for how to select the standard deviations.
71 *
72 * @param f A vector-valued function of x and u that returns
73 * the derivative of the state vector.
74 * @param h A vector-valued function of x and u that returns
75 * the measurement vector.
76 * @param stateStdDevs Standard deviations of model states.
77 * @param measurementStdDevs Standard deviations of measurements.
78 * @param dt Nominal discretization timestep.
79 */
81 std::function<StateVector(const StateVector&, const InputVector&)> f,
82 std::function<OutputVector(const StateVector&, const InputVector&)> h,
83 const StateArray& stateStdDevs, const OutputArray& measurementStdDevs,
84 units::second_t dt)
85 : m_f(std::move(f)), m_h(std::move(h)) {
86 m_contQ = MakeCovMatrix(stateStdDevs);
87 m_contR = MakeCovMatrix(measurementStdDevs);
88 m_meanFuncX = [](const Matrixd<States, 2 * States + 1>& sigmas,
90 return sigmas * Wm;
91 };
92 m_meanFuncY = [](const Matrixd<Outputs, 2 * States + 1>& sigmas,
94 return sigmas * Wc;
95 };
96 m_residualFuncX = [](const StateVector& a,
97 const StateVector& b) -> StateVector { return a - b; };
98 m_residualFuncY = [](const OutputVector& a,
99 const OutputVector& b) -> OutputVector {
100 return a - b;
101 };
102 m_addFuncX = [](const StateVector& a, const StateVector& b) -> StateVector {
103 return a + b;
104 };
105 m_dt = dt;
106
107 Reset();
108 }
109
110 /**
111 * Constructs an unscented Kalman filter with custom mean, residual, and
112 * addition functions. Using custom functions for arithmetic can be useful if
113 * you have angles in the state or measurements, because they allow you to
114 * correctly account for the modular nature of angle arithmetic.
115 *
116 * See
117 * https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices
118 * for how to select the standard deviations.
119 *
120 * @param f A vector-valued function of x and u that returns
121 * the derivative of the state vector.
122 * @param h A vector-valued function of x and u that returns
123 * the measurement vector.
124 * @param stateStdDevs Standard deviations of model states.
125 * @param measurementStdDevs Standard deviations of measurements.
126 * @param meanFuncX A function that computes the mean of 2 * States +
127 * 1 state vectors using a given set of weights.
128 * @param meanFuncY A function that computes the mean of 2 * States +
129 * 1 measurement vectors using a given set of
130 * weights.
131 * @param residualFuncX A function that computes the residual of two
132 * state vectors (i.e. it subtracts them.)
133 * @param residualFuncY A function that computes the residual of two
134 * measurement vectors (i.e. it subtracts them.)
135 * @param addFuncX A function that adds two state vectors.
136 * @param dt Nominal discretization timestep.
137 */
139 std::function<StateVector(const StateVector&, const InputVector&)> f,
140 std::function<OutputVector(const StateVector&, const InputVector&)> h,
141 const StateArray& stateStdDevs, const OutputArray& measurementStdDevs,
142 std::function<StateVector(const Matrixd<States, 2 * States + 1>&,
144 meanFuncX,
147 meanFuncY,
148 std::function<StateVector(const StateVector&, const StateVector&)>
149 residualFuncX,
150 std::function<OutputVector(const OutputVector&, const OutputVector&)>
151 residualFuncY,
152 std::function<StateVector(const StateVector&, const StateVector&)>
153 addFuncX,
154 units::second_t dt)
155 : m_f(std::move(f)),
156 m_h(std::move(h)),
157 m_meanFuncX(std::move(meanFuncX)),
158 m_meanFuncY(std::move(meanFuncY)),
159 m_residualFuncX(std::move(residualFuncX)),
160 m_residualFuncY(std::move(residualFuncY)),
161 m_addFuncX(std::move(addFuncX)) {
162 m_contQ = MakeCovMatrix(stateStdDevs);
163 m_contR = MakeCovMatrix(measurementStdDevs);
164 m_dt = dt;
165
166 Reset();
167 }
168
169 /**
170 * Returns the square-root error covariance matrix S.
171 */
172 const StateMatrix& S() const { return m_S; }
173
174 /**
175 * Returns an element of the square-root error covariance matrix S.
176 *
177 * @param i Row of S.
178 * @param j Column of S.
179 */
180 double S(int i, int j) const { return m_S(i, j); }
181
182 /**
183 * Set the current square-root error covariance matrix S.
184 *
185 * @param S The square-root error covariance matrix S.
186 */
187 void SetS(const StateMatrix& S) { m_S = S; }
188
189 /**
190 * Returns the reconstructed error covariance matrix P.
191 */
192 StateMatrix P() const { return m_S.transpose() * m_S; }
193
194 /**
195 * Set the current square-root error covariance matrix S by taking the square
196 * root of P.
197 *
198 * @param P The error covariance matrix P.
199 */
200 void SetP(const StateMatrix& P) { m_S = P.llt().matrixU(); }
201
202 /**
203 * Returns the state estimate x-hat.
204 */
205 const StateVector& Xhat() const { return m_xHat; }
206
207 /**
208 * Returns an element of the state estimate x-hat.
209 *
210 * @param i Row of x-hat.
211 */
212 double Xhat(int i) const { return m_xHat(i); }
213
214 /**
215 * Set initial state estimate x-hat.
216 *
217 * @param xHat The state estimate x-hat.
218 */
219 void SetXhat(const StateVector& xHat) { m_xHat = xHat; }
220
221 /**
222 * Set an element of the initial state estimate x-hat.
223 *
224 * @param i Row of x-hat.
225 * @param value Value for element of x-hat.
226 */
227 void SetXhat(int i, double value) { m_xHat(i) = value; }
228
229 /**
230 * Resets the observer.
231 */
232 void Reset() {
233 m_xHat.setZero();
234 m_S.setZero();
235 m_sigmasF.setZero();
236 }
237
238 /**
239 * Project the model into the future with a new control input u.
240 *
241 * @param u New control input from controller.
242 * @param dt Timestep for prediction.
243 */
244 void Predict(const InputVector& u, units::second_t dt) {
245 m_dt = dt;
246
247 // Discretize Q before projecting mean and covariance forward
248 StateMatrix contA =
250 StateMatrix discA;
251 StateMatrix discQ;
252 DiscretizeAQ<States>(contA, m_contQ, m_dt, &discA, &discQ);
253 Eigen::internal::llt_inplace<double, Eigen::Lower>::blocked(discQ);
254
256 m_pts.SquareRootSigmaPoints(m_xHat, m_S);
257
258 for (int i = 0; i < m_pts.NumSigmas(); ++i) {
259 StateVector x = sigmas.template block<States, 1>(0, i);
260 m_sigmasF.template block<States, 1>(0, i) = RK4(m_f, x, u, dt);
261 }
262
264 m_sigmasF, m_pts.Wm(), m_pts.Wc(), m_meanFuncX, m_residualFuncX,
265 discQ.template triangularView<Eigen::Lower>());
266 m_xHat = xHat;
267 m_S = S;
268 }
269
270 /**
271 * Correct the state estimate x-hat using the measurements in y.
272 *
273 * @param u Same control input used in the predict step.
274 * @param y Measurement vector.
275 */
276 void Correct(const InputVector& u, const OutputVector& y) {
277 Correct<Outputs>(u, y, m_h, m_contR, m_meanFuncY, m_residualFuncY,
278 m_residualFuncX, m_addFuncX);
279 }
280
281 /**
282 * Correct the state estimate x-hat using the measurements in y.
283 *
284 * This is useful for when the measurement noise covariances vary.
285 *
286 * @param u Same control input used in the predict step.
287 * @param y Measurement vector.
288 * @param R Continuous measurement noise covariance matrix.
289 */
290 void Correct(const InputVector& u, const OutputVector& y,
291 const Matrixd<Outputs, Outputs>& R) {
292 Correct<Outputs>(u, y, m_h, R, m_meanFuncY, m_residualFuncY,
293 m_residualFuncX, m_addFuncX);
294 }
295
296 /**
297 * Correct the state estimate x-hat using the measurements in y.
298 *
299 * This is useful for when the measurements available during a timestep's
300 * Correct() call vary. The h(x, u) passed to the constructor is used if one
301 * is not provided (the two-argument version of this function).
302 *
303 * @param u Same control input used in the predict step.
304 * @param y Measurement vector.
305 * @param h A vector-valued function of x and u that returns the measurement
306 * vector.
307 * @param R Continuous measurement noise covariance matrix.
308 */
309 template <int Rows>
311 const InputVector& u, const Vectord<Rows>& y,
312 std::function<Vectord<Rows>(const StateVector&, const InputVector&)> h,
313 const Matrixd<Rows, Rows>& R) {
314 auto meanFuncY = [](const Matrixd<Outputs, 2 * States + 1>& sigmas,
316 return sigmas * Wc;
317 };
318 auto residualFuncX = [](const StateVector& a,
319 const StateVector& b) -> StateVector {
320 return a - b;
321 };
322 auto residualFuncY = [](const Vectord<Rows>& a,
323 const Vectord<Rows>& b) -> Vectord<Rows> {
324 return a - b;
325 };
326 auto addFuncX = [](const StateVector& a,
327 const StateVector& b) -> StateVector { return a + b; };
328 Correct<Rows>(u, y, std::move(h), R, std::move(meanFuncY),
329 std::move(residualFuncY), std::move(residualFuncX),
330 std::move(addFuncX));
331 }
332
333 /**
334 * Correct the state estimate x-hat using the measurements in y.
335 *
336 * This is useful for when the measurements available during a timestep's
337 * Correct() call vary. The h(x, u) passed to the constructor is used if one
338 * is not provided (the two-argument version of this function).
339 *
340 * @param u Same control input used in the predict step.
341 * @param y Measurement vector.
342 * @param h A vector-valued function of x and u that returns the
343 * measurement vector.
344 * @param R Continuous measurement noise covariance matrix.
345 * @param meanFuncY A function that computes the mean of 2 * States + 1
346 * measurement vectors using a given set of weights.
347 * @param residualFuncY A function that computes the residual of two
348 * measurement vectors (i.e. it subtracts them.)
349 * @param residualFuncX A function that computes the residual of two state
350 * vectors (i.e. it subtracts them.)
351 * @param addFuncX A function that adds two state vectors.
352 */
353 template <int Rows>
355 const InputVector& u, const Vectord<Rows>& y,
356 std::function<Vectord<Rows>(const StateVector&, const InputVector&)> h,
357 const Matrixd<Rows, Rows>& R,
358 std::function<Vectord<Rows>(const Matrixd<Rows, 2 * States + 1>&,
360 meanFuncY,
361 std::function<Vectord<Rows>(const Vectord<Rows>&, const Vectord<Rows>&)>
362 residualFuncY,
363 std::function<StateVector(const StateVector&, const StateVector&)>
364 residualFuncX,
365 std::function<StateVector(const StateVector&, const StateVector&)>
366 addFuncX) {
367 Matrixd<Rows, Rows> discR = DiscretizeR<Rows>(R, m_dt);
368 Eigen::internal::llt_inplace<double, Eigen::Lower>::blocked(discR);
369
370 // Transform sigma points into measurement space
373 m_pts.SquareRootSigmaPoints(m_xHat, m_S);
374 for (int i = 0; i < m_pts.NumSigmas(); ++i) {
375 sigmasH.template block<Rows, 1>(0, i) =
376 h(sigmas.template block<States, 1>(0, i), u);
377 }
378
379 // Mean and covariance of prediction passed through UT
381 sigmasH, m_pts.Wm(), m_pts.Wc(), meanFuncY, residualFuncY,
382 discR.template triangularView<Eigen::Lower>());
383
384 // Compute cross covariance of the state and the measurements
386 Pxy.setZero();
387 for (int i = 0; i < m_pts.NumSigmas(); ++i) {
388 // Pxy += (sigmas_f[:, i] - x̂)(sigmas_h[:, i] - ŷ)ᵀ W_c[i]
389 Pxy +=
390 m_pts.Wc(i) *
391 (residualFuncX(m_sigmasF.template block<States, 1>(0, i), m_xHat)) *
392 (residualFuncY(sigmasH.template block<Rows, 1>(0, i), yHat))
393 .transpose();
394 }
395
396 // K = (P_{xy} / S_yᵀ) / S_y
397 // K = (S_y \ P_{xy}ᵀ)ᵀ / S_y
398 // K = (S_yᵀ \ (S_y \ P_{xy}ᵀ))ᵀ
400 Sy.transpose()
401 .fullPivHouseholderQr()
402 .solve(Sy.fullPivHouseholderQr().solve(Pxy.transpose()))
403 .transpose();
404
405 // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − ŷ)
406 m_xHat = addFuncX(m_xHat, K * residualFuncY(y, yHat));
407
408 Matrixd<States, Rows> U = K * Sy;
409 for (int i = 0; i < Rows; i++) {
410 Eigen::internal::llt_inplace<double, Eigen::Upper>::rankUpdate(
411 m_S, U.template block<States, 1>(0, i), -1);
412 }
413 }
414
415 private:
416 std::function<StateVector(const StateVector&, const InputVector&)> m_f;
417 std::function<OutputVector(const StateVector&, const InputVector&)> m_h;
418 std::function<StateVector(const Matrixd<States, 2 * States + 1>&,
420 m_meanFuncX;
423 m_meanFuncY;
424 std::function<StateVector(const StateVector&, const StateVector&)>
425 m_residualFuncX;
426 std::function<OutputVector(const OutputVector&, const OutputVector&)>
427 m_residualFuncY;
428 std::function<StateVector(const StateVector&, const StateVector&)> m_addFuncX;
429 StateVector m_xHat;
430 StateMatrix m_S;
431 StateMatrix m_contQ;
434 units::second_t m_dt;
435
437};
438
439extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT)
440 UnscentedKalmanFilter<3, 3, 1>;
441extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT)
442 UnscentedKalmanFilter<5, 3, 3>;
443
444} // namespace frc
#define WPILIB_DLLEXPORT
Definition SymbolExports.h:36
#define EXPORT_TEMPLATE_DECLARE(export)
Definition SymbolExports.hpp:92
Generates sigma points and weights according to Van der Merwe's 2004 dissertation[1] for the Unscente...
Definition MerweScaledSigmaPoints.h:28
A Kalman filter combines predictions from a model and measurements to give an estimate of the true sy...
Definition UnscentedKalmanFilter.h:54
const StateVector & Xhat() const
Returns the state estimate x-hat.
Definition UnscentedKalmanFilter.h:205
Matrixd< States, States > StateMatrix
Definition UnscentedKalmanFilter.h:63
const StateMatrix & S() const
Returns the square-root error covariance matrix S.
Definition UnscentedKalmanFilter.h:172
Vectord< Outputs > OutputVector
Definition UnscentedKalmanFilter.h:58
double Xhat(int i) const
Returns an element of the state estimate x-hat.
Definition UnscentedKalmanFilter.h:212
void Predict(const InputVector &u, units::second_t dt)
Project the model into the future with a new control input u.
Definition UnscentedKalmanFilter.h:244
void Correct(const InputVector &u, const Vectord< Rows > &y, std::function< Vectord< Rows >(const StateVector &, const InputVector &)> h, const Matrixd< Rows, Rows > &R, std::function< Vectord< Rows >(const Matrixd< Rows, 2 *States+1 > &, const Vectord< 2 *States+1 > &)> meanFuncY, std::function< Vectord< Rows >(const Vectord< Rows > &, const Vectord< Rows > &)> residualFuncY, std::function< StateVector(const StateVector &, const StateVector &)> residualFuncX, std::function< StateVector(const StateVector &, const StateVector &)> addFuncX)
Correct the state estimate x-hat using the measurements in y.
Definition UnscentedKalmanFilter.h:354
double S(int i, int j) const
Returns an element of the square-root error covariance matrix S.
Definition UnscentedKalmanFilter.h:180
Vectord< States > StateVector
Definition UnscentedKalmanFilter.h:56
void Correct(const InputVector &u, const OutputVector &y, const Matrixd< Outputs, Outputs > &R)
Correct the state estimate x-hat using the measurements in y.
Definition UnscentedKalmanFilter.h:290
void SetS(const StateMatrix &S)
Set the current square-root error covariance matrix S.
Definition UnscentedKalmanFilter.h:187
UnscentedKalmanFilter(std::function< StateVector(const StateVector &, const InputVector &)> f, std::function< OutputVector(const StateVector &, const InputVector &)> h, const StateArray &stateStdDevs, const OutputArray &measurementStdDevs, std::function< StateVector(const Matrixd< States, 2 *States+1 > &, const Vectord< 2 *States+1 > &)> meanFuncX, std::function< OutputVector(const Matrixd< Outputs, 2 *States+1 > &, const Vectord< 2 *States+1 > &)> meanFuncY, std::function< StateVector(const StateVector &, const StateVector &)> residualFuncX, std::function< OutputVector(const OutputVector &, const OutputVector &)> residualFuncY, std::function< StateVector(const StateVector &, const StateVector &)> addFuncX, units::second_t dt)
Constructs an unscented Kalman filter with custom mean, residual, and addition functions.
Definition UnscentedKalmanFilter.h:138
StateMatrix P() const
Returns the reconstructed error covariance matrix P.
Definition UnscentedKalmanFilter.h:192
void SetXhat(const StateVector &xHat)
Set initial state estimate x-hat.
Definition UnscentedKalmanFilter.h:219
UnscentedKalmanFilter(std::function< StateVector(const StateVector &, const InputVector &)> f, std::function< OutputVector(const StateVector &, const InputVector &)> h, const StateArray &stateStdDevs, const OutputArray &measurementStdDevs, units::second_t dt)
Constructs an unscented Kalman filter.
Definition UnscentedKalmanFilter.h:80
void SetXhat(int i, double value)
Set an element of the initial state estimate x-hat.
Definition UnscentedKalmanFilter.h:227
void Correct(const InputVector &u, const Vectord< Rows > &y, std::function< Vectord< Rows >(const StateVector &, const InputVector &)> h, const Matrixd< Rows, Rows > &R)
Correct the state estimate x-hat using the measurements in y.
Definition UnscentedKalmanFilter.h:310
Vectord< Inputs > InputVector
Definition UnscentedKalmanFilter.h:57
void SetP(const StateMatrix &P)
Set the current square-root error covariance matrix S by taking the square root of P.
Definition UnscentedKalmanFilter.h:200
void Reset()
Resets the observer.
Definition UnscentedKalmanFilter.h:232
void Correct(const InputVector &u, const OutputVector &y)
Correct the state estimate x-hat using the measurements in y.
Definition UnscentedKalmanFilter.h:276
This class is a wrapper around std::array that does compile time size checking.
Definition array.h:26
Definition CAN.h:11
std::tuple< Vectord< CovDim >, Matrixd< CovDim, CovDim > > SquareRootUnscentedTransform(const Matrixd< CovDim, 2 *States+1 > &sigmas, const Vectord< 2 *States+1 > &Wm, const Vectord< 2 *States+1 > &Wc, std::function< Vectord< CovDim >(const Matrixd< CovDim, 2 *States+1 > &, const Vectord< 2 *States+1 > &)> meanFunc, std::function< Vectord< CovDim >(const Vectord< CovDim > &, const Vectord< CovDim > &)> residualFunc, const Matrixd< CovDim, CovDim > &squareRootR)
Computes unscented transform of a set of sigma points and weights.
Definition UnscentedTransform.h:40
T RK4(F &&f, T x, units::second_t dt)
Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
Definition NumericalIntegration.h:23
void DiscretizeAQ(const Matrixd< States, States > &contA, const Matrixd< States, States > &contQ, units::second_t dt, Matrixd< States, States > *discA, Matrixd< States, States > *discQ)
Discretizes the given continuous A and Q matrices.
Definition Discretization.h:71
Eigen::Matrix< double, Rows, Cols, Options, MaxRows, MaxCols > Matrixd
Definition EigenCore.h:21
Eigen::Vector< double, Size > Vectord
Definition EigenCore.h:12
Matrixd< Outputs, Outputs > DiscretizeR(const Matrixd< Outputs, Outputs > &R, units::second_t dt)
Returns a discretized version of the provided continuous measurement noise covariance matrix.
Definition Discretization.h:113
constexpr Matrixd< sizeof...(Ts), sizeof...(Ts)> MakeCovMatrix(Ts... stdDevs)
Creates a covariance matrix from the given vector for use with Kalman filters.
Definition StateSpaceUtil.h:75
auto NumericalJacobianX(F &&f, const Vectord< States > &x, const Vectord< Inputs > &u, Args &&... args)
Returns numerical Jacobian with respect to x for f(x, u, ...).
Definition NumericalJacobian.h:51
Implement std::hash so that hash_code can be used in STL containers.
Definition PointerIntPair.h:280