WPILibC++ 2024.3.2
KalmanFilterLatencyCompensator.h
Go to the documentation of this file.
1// Copyright (c) FIRST and other WPILib contributors.
2// Open Source Software; you can modify and/or share it under the terms of
3// the WPILib BSD license file in the root directory of this project.
4
5#pragma once
6
7#include <algorithm>
8#include <array>
9#include <functional>
10#include <utility>
11#include <vector>
12
13#include "frc/EigenCore.h"
14#include "units/math.h"
15#include "units/time.h"
16
17namespace frc {
18
19/**
20 * This class incorporates time-delayed measurements into a Kalman filter's
21 * state estimate.
22 *
23 * @tparam States The number of states.
24 * @tparam Inputs The number of inputs.
25 * @tparam Outputs The number of outputs.
26 */
27template <int States, int Inputs, int Outputs, typename KalmanFilterType>
29 public:
30 /**
31 * This class contains all the information about our observer at a given time.
32 */
34 /// The state estimate.
36 /// The square root error covariance.
38 /// The inputs.
40 /// The local measurements.
42
43 ObserverSnapshot(const KalmanFilterType& observer, const Vectord<Inputs>& u,
44 const Vectord<Outputs>& localY)
45 : xHat(observer.Xhat()),
46 squareRootErrorCovariances(observer.S()),
47 inputs(u),
48 localMeasurements(localY) {}
49 };
50
51 /**
52 * Clears the observer snapshot buffer.
53 */
54 void Reset() { m_pastObserverSnapshots.clear(); }
55
56 /**
57 * Add past observer states to the observer snapshots list.
58 *
59 * @param observer The observer.
60 * @param u The input at the timestamp.
61 * @param localY The local output at the timestamp
62 * @param timestamp The timesnap of the state.
63 */
64 void AddObserverState(const KalmanFilterType& observer, Vectord<Inputs> u,
65 Vectord<Outputs> localY, units::second_t timestamp) {
66 // Add the new state into the vector.
67 m_pastObserverSnapshots.emplace_back(timestamp,
68 ObserverSnapshot{observer, u, localY});
69
70 // Remove the oldest snapshot if the vector exceeds our maximum size.
71 if (m_pastObserverSnapshots.size() > kMaxPastObserverStates) {
72 m_pastObserverSnapshots.erase(m_pastObserverSnapshots.begin());
73 }
74 }
75
76 /**
77 * Add past global measurements (such as from vision)to the estimator.
78 *
79 * @param observer The observer to apply the past global
80 * measurement.
81 * @param nominalDt The nominal timestep.
82 * @param y The measurement.
83 * @param globalMeasurementCorrect The function take calls correct() on the
84 * observer.
85 * @param timestamp The timestamp of the measurement.
86 */
87 template <int Rows>
89 KalmanFilterType* observer, units::second_t nominalDt, Vectord<Rows> y,
90 std::function<void(const Vectord<Inputs>& u, const Vectord<Rows>& y)>
91 globalMeasurementCorrect,
92 units::second_t timestamp) {
93 if (m_pastObserverSnapshots.size() == 0) {
94 // State map was empty, which means that we got a measurement right at
95 // startup. The only thing we can do is ignore the measurement.
96 return;
97 }
98
99 // Perform a binary search to find the index of first snapshot whose
100 // timestamp is greater than or equal to the global measurement timestamp
101 auto it = std::lower_bound(
102 m_pastObserverSnapshots.cbegin(), m_pastObserverSnapshots.cend(),
103 timestamp,
104 [](const auto& entry, const auto& ts) { return entry.first < ts; });
105
106 size_t indexOfClosestEntry;
107
108 if (it == m_pastObserverSnapshots.cbegin()) {
109 // If the global measurement is older than any snapshot, throw out the
110 // measurement because there's no state estimate into which to incorporate
111 // the measurement
112 if (timestamp < it->first) {
113 return;
114 }
115
116 // If the first snapshot has same timestamp as the global measurement, use
117 // that snapshot
118 indexOfClosestEntry = 0;
119 } else if (it == m_pastObserverSnapshots.cend()) {
120 // If all snapshots are older than the global measurement, use the newest
121 // snapshot
122 indexOfClosestEntry = m_pastObserverSnapshots.size() - 1;
123 } else {
124 // Index of snapshot taken after the global measurement
125 int nextIdx = std::distance(m_pastObserverSnapshots.cbegin(), it);
126
127 // Index of snapshot taken before the global measurement. Since we already
128 // handled the case where the index points to the first snapshot, this
129 // computation is guaranteed to be nonnegative.
130 int prevIdx = nextIdx - 1;
131
132 // Find the snapshot closest in time to global measurement
133 units::second_t prevTimeDiff =
134 units::math::abs(timestamp - m_pastObserverSnapshots[prevIdx].first);
135 units::second_t nextTimeDiff =
136 units::math::abs(timestamp - m_pastObserverSnapshots[nextIdx].first);
137 indexOfClosestEntry = prevTimeDiff < nextTimeDiff ? prevIdx : nextIdx;
138 }
139
140 units::second_t lastTimestamp =
141 m_pastObserverSnapshots[indexOfClosestEntry].first - nominalDt;
142
143 // We will now go back in time to the state of the system at the time when
144 // the measurement was captured. We will reset the observer to that state,
145 // and apply correction based on the measurement. Then, we will go back
146 // through all observer states until the present and apply past inputs to
147 // get the present estimated state.
148 for (size_t i = indexOfClosestEntry; i < m_pastObserverSnapshots.size();
149 ++i) {
150 auto& [key, snapshot] = m_pastObserverSnapshots[i];
151
152 if (i == indexOfClosestEntry) {
153 observer->SetS(snapshot.squareRootErrorCovariances);
154 observer->SetXhat(snapshot.xHat);
155 }
156
157 observer->Predict(snapshot.inputs, key - lastTimestamp);
158 observer->Correct(snapshot.inputs, snapshot.localMeasurements);
159
160 if (i == indexOfClosestEntry) {
161 // Note that the measurement is at a timestep close but probably not
162 // exactly equal to the timestep for which we called predict. This makes
163 // the assumption that the dt is small enough that the difference
164 // between the measurement time and the time that the inputs were
165 // captured at is very small.
166 globalMeasurementCorrect(snapshot.inputs, y);
167 }
168
169 lastTimestamp = key;
170 snapshot = ObserverSnapshot{*observer, snapshot.inputs,
171 snapshot.localMeasurements};
172 }
173 }
174
175 private:
176 static constexpr size_t kMaxPastObserverStates = 300;
177 std::vector<std::pair<units::second_t, ObserverSnapshot>>
178 m_pastObserverSnapshots;
179};
180} // namespace frc
This class incorporates time-delayed measurements into a Kalman filter's state estimate.
Definition: KalmanFilterLatencyCompensator.h:28
void ApplyPastGlobalMeasurement(KalmanFilterType *observer, units::second_t nominalDt, Vectord< Rows > y, std::function< void(const Vectord< Inputs > &u, const Vectord< Rows > &y)> globalMeasurementCorrect, units::second_t timestamp)
Add past global measurements (such as from vision) to the estimator.
Definition: KalmanFilterLatencyCompensator.h:88
void Reset()
Clears the observer snapshot buffer.
Definition: KalmanFilterLatencyCompensator.h:54
void AddObserverState(const KalmanFilterType &observer, Vectord< Inputs > u, Vectord< Outputs > localY, units::second_t timestamp)
Add past observer states to the observer snapshots list.
Definition: KalmanFilterLatencyCompensator.h:64
UnitType abs(const UnitType x) noexcept
Compute absolute value.
Definition: math.h:721
const T & first(const T &value, const Tail &...)
Definition: compile.h:60
Definition: AprilTagPoseEstimator.h:15
Eigen::Matrix< double, Rows, Cols, Options, MaxRows, MaxCols > Matrixd
Definition: EigenCore.h:21
Eigen::Vector< double, Size > Vectord
Definition: EigenCore.h:12
This class contains all the information about our observer at a given time.
Definition: KalmanFilterLatencyCompensator.h:33
Vectord< Inputs > inputs
The inputs.
Definition: KalmanFilterLatencyCompensator.h:39
Vectord< States > xHat
The state estimate.
Definition: KalmanFilterLatencyCompensator.h:35
Vectord< Outputs > localMeasurements
The local measurements.
Definition: KalmanFilterLatencyCompensator.h:41
Matrixd< States, States > squareRootErrorCovariances
The square root error covariance.
Definition: KalmanFilterLatencyCompensator.h:37
ObserverSnapshot(const KalmanFilterType &observer, const Vectord< Inputs > &u, const Vectord< Outputs > &localY)
Definition: KalmanFilterLatencyCompensator.h:43
#define S(label, offset, message)
Definition: Errors.h:119