WPILibC++ 2025.2.1
Loading...
Searching...
No Matches
ExtendedKalmanFilter.h
Go to the documentation of this file.
1// Copyright (c) FIRST and other WPILib contributors.
2// Open Source Software; you can modify and/or share it under the terms of
3// the WPILib BSD license file in the root directory of this project.
4
5#pragma once
6
7#include <functional>
8#include <string>
9#include <utility>
10
11#include <Eigen/Cholesky>
12#include <wpi/array.h>
13
14#include "frc/DARE.h"
15#include "frc/EigenCore.h"
16#include "frc/StateSpaceUtil.h"
17#include "frc/fmt/Eigen.h"
21#include "units/time.h"
22
23namespace frc {
24
25/**
26 * A Kalman filter combines predictions from a model and measurements to give an
27 * estimate of the true system state. This is useful because many states cannot
28 * be measured directly as a result of sensor noise, or because the state is
29 * "hidden".
30 *
31 * Kalman filters use a K gain matrix to determine whether to trust the model or
32 * measurements more. Kalman filter theory uses statistics to compute an optimal
33 * K gain which minimizes the sum of squares error in the state estimate. This K
34 * gain is used to correct the state estimate by some amount of the difference
35 * between the actual measurements and the measurements predicted by the model.
36 *
37 * An extended Kalman filter supports nonlinear state and measurement models. It
38 * propagates the error covariance by linearizing the models around the state
39 * estimate, then applying the linear Kalman filter equations.
40 *
41 * For more on the underlying math, read
42 * https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9
43 * "Stochastic control theory".
44 *
45 * @tparam States Number of states.
46 * @tparam Inputs Number of inputs.
47 * @tparam Outputs Number of outputs.
48 */
49template <int States, int Inputs, int Outputs>
51 public:
55
58
60
61 /**
62 * Constructs an extended Kalman filter.
63 *
64 * See
65 * https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices
66 * for how to select the standard deviations.
67 *
68 * @param f A vector-valued function of x and u that returns
69 * the derivative of the state vector.
70 * @param h A vector-valued function of x and u that returns
71 * the measurement vector.
72 * @param stateStdDevs Standard deviations of model states.
73 * @param measurementStdDevs Standard deviations of measurements.
74 * @param dt Nominal discretization timestep.
75 */
77 std::function<StateVector(const StateVector&, const InputVector&)> f,
78 std::function<OutputVector(const StateVector&, const InputVector&)> h,
79 const StateArray& stateStdDevs, const OutputArray& measurementStdDevs,
80 units::second_t dt)
81 : m_f(std::move(f)), m_h(std::move(h)) {
82 m_contQ = MakeCovMatrix(stateStdDevs);
83 m_contR = MakeCovMatrix(measurementStdDevs);
84 m_residualFuncY = [](const OutputVector& a,
85 const OutputVector& b) -> OutputVector {
86 return a - b;
87 };
88 m_addFuncX = [](const StateVector& a, const StateVector& b) -> StateVector {
89 return a + b;
90 };
91 m_dt = dt;
92
94 m_f, m_xHat, InputVector::Zero());
96 m_h, m_xHat, InputVector::Zero());
97
98 StateMatrix discA;
99 StateMatrix discQ;
100 DiscretizeAQ<States>(contA, m_contQ, dt, &discA, &discQ);
101
103
104 if (IsDetectable<States, Outputs>(discA, C) && Outputs <= States) {
105 if (auto P = DARE<States, Outputs>(discA.transpose(), C.transpose(),
106 discQ, discR)) {
107 m_initP = P.value();
108 } else if (P.error() == DAREError::QNotSymmetric ||
110 std::string msg =
111 fmt::format("{}\n\nQ =\n{}\n", to_string(P.error()), discQ);
112
114 throw std::invalid_argument(msg);
115 } else if (P.error() == DAREError::RNotSymmetric ||
116 P.error() == DAREError::RNotPositiveDefinite) {
117 std::string msg =
118 fmt::format("{}\n\nR =\n{}\n", to_string(P.error()), discR);
119
121 throw std::invalid_argument(msg);
122 } else if (P.error() == DAREError::ABNotStabilizable) {
// The message has exactly two placeholders (A, then C). Passing
// to_string(P.error()) ahead of the matrices shifted every argument by one:
// the error text was printed under "A =", discA under "C =", and C was
// dropped. Only the two matrices are passed so each lands under its label.
123 std::string msg = fmt::format(
124 "The (A, C) pair is not detectable.\n\nA =\n{}\nC =\n{}\n",
125 discA, C);
126
128 throw std::invalid_argument(msg);
129 } else if (P.error() == DAREError::ACNotDetectable) {
130 std::string msg = fmt::format("{}\n\nA =\n{}\nQ =\n{}\n",
131 to_string(P.error()), discA, discQ);
132
134 throw std::invalid_argument(msg);
135 }
136 } else {
137 m_initP = StateMatrix::Zero();
138 }
139 m_P = m_initP;
140 }
141
142 /**
143 * Constructs an extended Kalman filter.
144 *
145 * See
146 * https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices
147 * for how to select the standard deviations.
148 *
149 * @param f A vector-valued function of x and u that returns
150 * the derivative of the state vector.
151 * @param h A vector-valued function of x and u that returns
152 * the measurement vector.
153 * @param stateStdDevs Standard deviations of model states.
154 * @param measurementStdDevs Standard deviations of measurements.
155 * @param residualFuncY A function that computes the residual of two
156 * measurement vectors (i.e. it subtracts them.)
157 * @param addFuncX A function that adds two state vectors.
158 * @param dt Nominal discretization timestep.
159 */
161 std::function<StateVector(const StateVector&, const InputVector&)> f,
162 std::function<OutputVector(const StateVector&, const InputVector&)> h,
163 const StateArray& stateStdDevs, const OutputArray& measurementStdDevs,
164 std::function<OutputVector(const OutputVector&, const OutputVector&)>
165 residualFuncY,
166 std::function<StateVector(const StateVector&, const StateVector&)>
167 addFuncX,
168 units::second_t dt)
169 : m_f(std::move(f)),
170 m_h(std::move(h)),
171 m_residualFuncY(std::move(residualFuncY)),
172 m_addFuncX(std::move(addFuncX)) {
173 m_contQ = MakeCovMatrix(stateStdDevs);
174 m_contR = MakeCovMatrix(measurementStdDevs);
175 m_dt = dt;
176
178 m_f, m_xHat, InputVector::Zero());
180 m_h, m_xHat, InputVector::Zero());
181
182 StateMatrix discA;
183 StateMatrix discQ;
184 DiscretizeAQ<States>(contA, m_contQ, dt, &discA, &discQ);
185
187
188 if (IsDetectable<States, Outputs>(discA, C) && Outputs <= States) {
189 if (auto P = DARE<States, Outputs>(discA.transpose(), C.transpose(),
190 discQ, discR)) {
191 m_initP = P.value();
192 } else if (P.error() == DAREError::QNotSymmetric ||
194 std::string msg =
195 fmt::format("{}\n\nQ =\n{}\n", to_string(P.error()), discQ);
196
198 throw std::invalid_argument(msg);
199 } else if (P.error() == DAREError::RNotSymmetric ||
200 P.error() == DAREError::RNotPositiveDefinite) {
201 std::string msg =
202 fmt::format("{}\n\nR =\n{}\n", to_string(P.error()), discR);
203
205 throw std::invalid_argument(msg);
206 } else if (P.error() == DAREError::ABNotStabilizable) {
// The message has exactly two placeholders (A, then C). Passing
// to_string(P.error()) ahead of the matrices shifted every argument by one:
// the error text was printed under "A =", discA under "C =", and C was
// dropped. Only the two matrices are passed so each lands under its label.
207 std::string msg = fmt::format(
208 "The (A, C) pair is not detectable.\n\nA =\n{}\nC =\n{}\n",
209 discA, C);
210
212 throw std::invalid_argument(msg);
213 } else if (P.error() == DAREError::ACNotDetectable) {
214 std::string msg = fmt::format("{}\n\nA =\n{}\nQ =\n{}\n",
215 to_string(P.error()), discA, discQ);
216
218 throw std::invalid_argument(msg);
219 }
220 } else {
221 m_initP = StateMatrix::Zero();
222 }
223 m_P = m_initP;
224 }
225
226 /**
227 * Returns the error covariance matrix P.
 *
 * @return Const reference to the filter's current error covariance matrix.
228 */
229 const StateMatrix& P() const { return m_P; }
230
231 /**
232 * Returns an element of the error covariance matrix P.
233 *
234 * @param i Row of P.
235 * @param j Column of P.
 * @return The value of P at row i, column j.
236 */
237 double P(int i, int j) const { return m_P(i, j); }
238
239 /**
240 * Set the current error covariance matrix P.
 *
 * Only the current covariance is replaced; the initial covariance that
 * Reset() restores is left unchanged.
241 *
242 * @param P The error covariance matrix P.
243 */
244 void SetP(const StateMatrix& P) { m_P = P; }
245
246 /**
247 * Returns the state estimate x-hat.
 *
 * @return Const reference to the filter's current state estimate.
248 */
249 const StateVector& Xhat() const { return m_xHat; }
250
251 /**
252 * Returns an element of the state estimate x-hat.
253 *
254 * @param i Row of x-hat.
 * @return The value of x-hat at row i.
255 */
256 double Xhat(int i) const { return m_xHat(i); }
257
258 /**
259 * Sets the state estimate x-hat, overwriting the current estimate.
260 *
261 * @param xHat The state estimate x-hat.
262 */
263 void SetXhat(const StateVector& xHat) { m_xHat = xHat; }
264
265 /**
266 * Sets one element of the state estimate x-hat, overwriting its current
 * value. The other elements are left unchanged.
267 *
268 * @param i Row of x-hat.
269 * @param value Value for element of x-hat.
270 */
271 void SetXhat(int i, double value) { m_xHat(i) = value; }
272
273 /**
274 * Resets the observer.
 *
 * The state estimate x-hat is set to zero and the error covariance P is
 * restored to the initial covariance computed by the constructor.
275 */
276 void Reset() {
277 m_xHat.setZero();
// Restore the constructor-computed covariance, not zero.
278 m_P = m_initP;
279 }
280
281 /**
282 * Project the model into the future with a new control input u.
283 *
284 * @param u New control input from controller.
285 * @param dt Timestep for prediction.
286 */
287 void Predict(const InputVector& u, units::second_t dt) {
288 // Find continuous A
289 StateMatrix contA =
291
292 // Find discrete A and Q
293 StateMatrix discA;
294 StateMatrix discQ;
295 DiscretizeAQ<States>(contA, m_contQ, dt, &discA, &discQ);
296
297 m_xHat = RK4(m_f, m_xHat, u, dt);
298
299 // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q
300 m_P = discA * m_P * discA.transpose() + discQ;
301
302 m_dt = dt;
303 }
304
305 /**
306 * Correct the state estimate x-hat using the measurements in y.
 *
 * Uses the measurement model h, the measurement noise covariance built from
 * the measurement standard deviations, and the residual and state-addition
 * functions that were stored by the constructor.
307 *
308 * @param u Same control input used in the predict step.
309 * @param y Measurement vector.
310 */
311 void Correct(const InputVector& u, const OutputVector& y) {
// Delegate to the general overload with Rows = Outputs and the stored
// model, covariance and residual/addition functions.
312 Correct<Outputs>(u, y, m_h, m_contR, m_residualFuncY, m_addFuncX);
313 }
314
315 /**
316 * Correct the state estimate x-hat using the measurements in y.
317 *
318 * This is useful for when the measurement noise covariances vary.
 * The measurement model and the residual and state-addition functions stored
 * by the constructor are still used; only R is supplied by the caller.
319 *
320 * @param u Same control input used in the predict step.
321 * @param y Measurement vector.
322 * @param R Continuous measurement noise covariance matrix.
323 */
324 void Correct(const InputVector& u, const OutputVector& y,
325 const Matrixd<Outputs, Outputs>& R) {
// Same delegation as the two-argument overload, with the caller's R in
// place of the stored measurement covariance.
326 Correct<Outputs>(u, y, m_h, R, m_residualFuncY, m_addFuncX);
327 }
328
329 /**
330 * Correct the state estimate x-hat using the measurements in y.
331 *
332 * This is useful for when the measurements available during a timestep's
333 * Correct() call vary. The h(x, u) passed to the constructor is used if one
334 * is not provided (the two-argument version of this function).
335 *
336 * @param u Same control input used in the predict step.
337 * @param y Measurement vector.
338 * @param h A vector-valued function of x and u that returns the measurement
339 * vector.
340 * @param R Continuous measurement noise covariance matrix.
341 */
342 template <int Rows>
344 const InputVector& u, const Vectord<Rows>& y,
345 std::function<Vectord<Rows>(const StateVector&, const InputVector&)> h,
346 const Matrixd<Rows, Rows>& R) {
347 auto residualFuncY = [](const Vectord<Rows>& a,
348 const Vectord<Rows>& b) -> Vectord<Rows> {
349 return a - b;
350 };
351 auto addFuncX = [](const StateVector& a,
352 const StateVector& b) -> StateVector { return a + b; };
353 Correct<Rows>(u, y, std::move(h), R, std::move(residualFuncY),
354 std::move(addFuncX));
355 }
356
357 /**
358 * Correct the state estimate x-hat using the measurements in y.
359 *
360 * This is useful for when the measurements available during a timestep's
361 * Correct() call vary. The h(x, u) passed to the constructor is used if one
362 * is not provided (the two-argument version of this function).
363 *
364 * @param u Same control input used in the predict step.
365 * @param y Measurement vector.
366 * @param h A vector-valued function of x and u that returns
367 * the measurement vector.
368 * @param R Continuous measurement noise covariance matrix.
369 * @param residualFuncY A function that computes the residual of two
370 * measurement vectors (i.e. it subtracts them.)
371 * @param addFuncX A function that adds two state vectors.
372 */
373 template <int Rows>
375 const InputVector& u, const Vectord<Rows>& y,
376 std::function<Vectord<Rows>(const StateVector&, const InputVector&)> h,
377 const Matrixd<Rows, Rows>& R,
378 std::function<Vectord<Rows>(const Vectord<Rows>&, const Vectord<Rows>&)>
379 residualFuncY,
380 std::function<StateVector(const StateVector&, const StateVector&)>
381 addFuncX) {
382 const Matrixd<Rows, States> C =
384 const Matrixd<Rows, Rows> discR = DiscretizeR<Rows>(R, m_dt);
385
386 Matrixd<Rows, Rows> S = C * m_P * C.transpose() + discR;
387
388 // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more
389 // efficiently.
390 //
391 // K = PCᵀS⁻¹
392 // KS = PCᵀ
393 // (KS)ᵀ = (PCᵀ)ᵀ
394 // SᵀKᵀ = CPᵀ
395 //
396 // The solution of Ax = b can be found via x = A.solve(b).
397 //
398 // Kᵀ = Sᵀ.solve(CPᵀ)
399 // K = (Sᵀ.solve(CPᵀ))ᵀ
401 S.transpose().ldlt().solve(C * m_P.transpose()).transpose();
402
403 // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + Kₖ₊₁(y − h(x̂ₖ₊₁⁻, uₖ₊₁))
404 m_xHat = addFuncX(m_xHat, K * residualFuncY(y, h(m_xHat, u)));
405
406 // Pₖ₊₁⁺ = (I−Kₖ₊₁C)Pₖ₊₁⁻(I−Kₖ₊₁C)ᵀ + Kₖ₊₁RKₖ₊₁ᵀ
407 // Use Joseph form for numerical stability
408 m_P = (StateMatrix::Identity() - K * C) * m_P *
409 (StateMatrix::Identity() - K * C).transpose() +
410 K * discR * K.transpose();
411 }
412
413 private:
414 std::function<StateVector(const StateVector&, const InputVector&)> m_f;
415 std::function<OutputVector(const StateVector&, const InputVector&)> m_h;
416 std::function<OutputVector(const OutputVector&, const OutputVector&)>
417 m_residualFuncY;
418 std::function<StateVector(const StateVector&, const StateVector&)> m_addFuncX;
419 StateVector m_xHat = StateVector::Zero();
420 StateMatrix m_P;
421 StateMatrix m_contQ;
423 units::second_t m_dt;
424
425 StateMatrix m_initP;
426};
427
428} // namespace frc
A Kalman filter combines predictions from a model and measurements to give an estimate of the true sy...
Definition ExtendedKalmanFilter.h:50
void Correct(const InputVector &u, const Vectord< Rows > &y, std::function< Vectord< Rows >(const StateVector &, const InputVector &)> h, const Matrixd< Rows, Rows > &R, std::function< Vectord< Rows >(const Vectord< Rows > &, const Vectord< Rows > &)> residualFuncY, std::function< StateVector(const StateVector &, const StateVector &)> addFuncX)
Correct the state estimate x-hat using the measurements in y.
Definition ExtendedKalmanFilter.h:374
void Correct(const InputVector &u, const OutputVector &y)
Correct the state estimate x-hat using the measurements in y.
Definition ExtendedKalmanFilter.h:311
ExtendedKalmanFilter(std::function< StateVector(const StateVector &, const InputVector &)> f, std::function< OutputVector(const StateVector &, const InputVector &)> h, const StateArray &stateStdDevs, const OutputArray &measurementStdDevs, std::function< OutputVector(const OutputVector &, const OutputVector &)> residualFuncY, std::function< StateVector(const StateVector &, const StateVector &)> addFuncX, units::second_t dt)
Constructs an extended Kalman filter.
Definition ExtendedKalmanFilter.h:160
const StateMatrix & P() const
Returns the error covariance matrix P.
Definition ExtendedKalmanFilter.h:229
ExtendedKalmanFilter(std::function< StateVector(const StateVector &, const InputVector &)> f, std::function< OutputVector(const StateVector &, const InputVector &)> h, const StateArray &stateStdDevs, const OutputArray &measurementStdDevs, units::second_t dt)
Constructs an extended Kalman filter.
Definition ExtendedKalmanFilter.h:76
Vectord< Inputs > InputVector
Definition ExtendedKalmanFilter.h:53
void Reset()
Resets the observer.
Definition ExtendedKalmanFilter.h:276
void Predict(const InputVector &u, units::second_t dt)
Project the model into the future with a new control input u.
Definition ExtendedKalmanFilter.h:287
Matrixd< States, States > StateMatrix
Definition ExtendedKalmanFilter.h:59
Vectord< Outputs > OutputVector
Definition ExtendedKalmanFilter.h:54
void SetP(const StateMatrix &P)
Set the current error covariance matrix P.
Definition ExtendedKalmanFilter.h:244
void SetXhat(int i, double value)
Set an element of the initial state estimate x-hat.
Definition ExtendedKalmanFilter.h:271
const StateVector & Xhat() const
Returns the state estimate x-hat.
Definition ExtendedKalmanFilter.h:249
Vectord< States > StateVector
Definition ExtendedKalmanFilter.h:52
void SetXhat(const StateVector &xHat)
Set initial state estimate x-hat.
Definition ExtendedKalmanFilter.h:263
double P(int i, int j) const
Returns an element of the error covariance matrix P.
Definition ExtendedKalmanFilter.h:237
void Correct(const InputVector &u, const OutputVector &y, const Matrixd< Outputs, Outputs > &R)
Correct the state estimate x-hat using the measurements in y.
Definition ExtendedKalmanFilter.h:324
void Correct(const InputVector &u, const Vectord< Rows > &y, std::function< Vectord< Rows >(const StateVector &, const InputVector &)> h, const Matrixd< Rows, Rows > &R)
Correct the state estimate x-hat using the measurements in y.
Definition ExtendedKalmanFilter.h:343
double Xhat(int i) const
Returns an element of the state estimate x-hat.
Definition ExtendedKalmanFilter.h:256
This class is a wrapper around std::array that does compile time size checking.
Definition array.h:26
static void ReportError(const S &format, Args &&... args)
Definition MathShared.h:62
Definition CAN.h:11
T RK4(F &&f, T x, units::second_t dt)
Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
Definition NumericalIntegration.h:23
void DiscretizeAQ(const Matrixd< States, States > &contA, const Matrixd< States, States > &contQ, units::second_t dt, Matrixd< States, States > *discA, Matrixd< States, States > *discQ)
Discretizes the given continuous A and Q matrices.
Definition Discretization.h:71
Eigen::Matrix< double, Rows, Cols, Options, MaxRows, MaxCols > Matrixd
Definition EigenCore.h:21
Eigen::Vector< double, Size > Vectord
Definition EigenCore.h:12
Matrixd< Outputs, Outputs > DiscretizeR(const Matrixd< Outputs, Outputs > &R, units::second_t dt)
Returns a discretized version of the provided continuous measurement noise covariance matrix.
Definition Discretization.h:113
constexpr Matrixd< sizeof...(Ts), sizeof...(Ts)> MakeCovMatrix(Ts... stdDevs)
Creates a covariance matrix from the given vector for use with Kalman filters.
Definition StateSpaceUtil.h:75
wpi::expected< Eigen::Matrix< double, States, States >, DAREError > DARE(const Eigen::Matrix< double, States, States > &A, const Eigen::Matrix< double, States, Inputs > &B, const Eigen::Matrix< double, States, States > &Q, const Eigen::Matrix< double, Inputs, Inputs > &R, bool checkPreconditions=true)
Computes the unique stabilizing solution X to the discrete-time algebraic Riccati equation:
Definition DARE.h:165
constexpr std::string_view to_string(const DAREError &error)
Converts the given DAREError enum to a string.
Definition DARE.h:39
@ QNotPositiveSemidefinite
Q was not positive semidefinite.
@ RNotSymmetric
R was not symmetric.
@ QNotSymmetric
Q was not symmetric.
@ ACNotDetectable
(A, C) pair where Q = CᵀC was not detectable.
@ RNotPositiveDefinite
R was not positive definite.
@ ABNotStabilizable
(A, B) pair was not stabilizable.
bool IsDetectable(const Matrixd< States, States > &A, const Matrixd< Outputs, States > &C)
Returns true if (A, C) is a detectable pair.
Definition StateSpaceUtil.h:309
auto NumericalJacobianX(F &&f, const Vectord< States > &x, const Vectord< Inputs > &u, Args &&... args)
Returns numerical Jacobian with respect to x for f(x, u, ...).
Definition NumericalJacobian.h:51
Implement std::hash so that hash_code can be used in STL containers.
Definition PointerIntPair.h:280
#define S(label, offset, message)
Definition Errors.h:113