001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.estimator; 006 007import edu.wpi.first.math.Matrix; 008import edu.wpi.first.math.Nat; 009import edu.wpi.first.math.Num; 010import edu.wpi.first.math.numbers.N1; 011import java.util.ArrayList; 012import java.util.List; 013import java.util.Map; 014import java.util.function.BiConsumer; 015 016/** 017 * This class incorporates time-delayed measurements into a Kalman filter's state estimate. 018 * 019 * @param <S> The number of states. 020 * @param <I> The number of inputs. 021 * @param <O> The number of outputs. 022 */ 023public class KalmanFilterLatencyCompensator<S extends Num, I extends Num, O extends Num> { 024 private static final int kMaxPastObserverStates = 300; 025 026 private final List<Map.Entry<Double, ObserverSnapshot>> m_pastObserverSnapshots; 027 028 /** Default constructor. */ 029 KalmanFilterLatencyCompensator() { 030 m_pastObserverSnapshots = new ArrayList<>(); 031 } 032 033 /** Clears the observer snapshot buffer. */ 034 public void reset() { 035 m_pastObserverSnapshots.clear(); 036 } 037 038 /** 039 * Add past observer states to the observer snapshots list. 040 * 041 * @param observer The observer. 042 * @param u The input at the timestamp. 043 * @param localY The local output at the timestamp 044 * @param timestampSeconds The timestamp of the state. 045 */ 046 public void addObserverState( 047 KalmanTypeFilter<S, I, O> observer, 048 Matrix<I, N1> u, 049 Matrix<O, N1> localY, 050 double timestampSeconds) { 051 m_pastObserverSnapshots.add( 052 Map.entry(timestampSeconds, new ObserverSnapshot(observer, u, localY))); 053 054 if (m_pastObserverSnapshots.size() > kMaxPastObserverStates) { 055 m_pastObserverSnapshots.remove(0); 056 } 057 } 058 059 /** 060 * Add past global measurements (such as from vision)to the estimator. 061 * 062 * @param <R> The rows in the global measurement vector. 063 * @param rows The rows in the global measurement vector. 064 * @param observer The observer to apply the past global measurement. 065 * @param nominalDtSeconds The nominal timestep. 066 * @param y The measurement. 067 * @param globalMeasurementCorrect The function take calls correct() on the observer. 068 * @param timestampSeconds The timestamp of the measurement. 069 */ 070 public <R extends Num> void applyPastGlobalMeasurement( 071 Nat<R> rows, 072 KalmanTypeFilter<S, I, O> observer, 073 double nominalDtSeconds, 074 Matrix<R, N1> y, 075 BiConsumer<Matrix<I, N1>, Matrix<R, N1>> globalMeasurementCorrect, 076 double timestampSeconds) { 077 if (m_pastObserverSnapshots.isEmpty()) { 078 // State map was empty, which means that we got a past measurement right at startup. The only 079 // thing we can really do is ignore the measurement. 080 return; 081 } 082 083 // Use a less verbose name for timestamp 084 double timestamp = timestampSeconds; 085 086 int maxIdx = m_pastObserverSnapshots.size() - 1; 087 int low = 0; 088 int high = maxIdx; 089 090 // Perform a binary search to find the index of first snapshot whose 091 // timestamp is greater than or equal to the global measurement timestamp 092 while (low != high) { 093 int mid = (low + high) / 2; 094 if (m_pastObserverSnapshots.get(mid).getKey() < timestamp) { 095 // This index and everything under it are less than the requested timestamp. Therefore, we 096 // can discard them. 097 low = mid + 1; 098 } else { 099 // t is at least as large as the element at this index. This means that anything after it 100 // cannot be what we are looking for. 101 high = mid; 102 } 103 } 104 105 int indexOfClosestEntry; 106 107 if (low == 0) { 108 // If the global measurement is older than any snapshot, throw out the 109 // measurement because there's no state estimate into which to incorporate 110 // the measurement 111 if (timestamp < m_pastObserverSnapshots.get(low).getKey()) { 112 return; 113 } 114 115 // If the first snapshot has same timestamp as the global measurement, use 116 // that snapshot 117 indexOfClosestEntry = 0; 118 } else if (low == maxIdx && m_pastObserverSnapshots.get(low).getKey() < timestamp) { 119 // If all snapshots are older than the global measurement, use the newest 120 // snapshot 121 indexOfClosestEntry = maxIdx; 122 } else { 123 // Index of snapshot taken after the global measurement 124 int nextIdx = low; 125 126 // Index of snapshot taken before the global measurement. Since we already 127 // handled the case where the index points to the first snapshot, this 128 // computation is guaranteed to be non-negative. 129 int prevIdx = nextIdx - 1; 130 131 // Find the snapshot closest in time to global measurement 132 double prevTimeDiff = Math.abs(timestamp - m_pastObserverSnapshots.get(prevIdx).getKey()); 133 double nextTimeDiff = Math.abs(timestamp - m_pastObserverSnapshots.get(nextIdx).getKey()); 134 indexOfClosestEntry = prevTimeDiff <= nextTimeDiff ? prevIdx : nextIdx; 135 } 136 137 double lastTimestamp = 138 m_pastObserverSnapshots.get(indexOfClosestEntry).getKey() - nominalDtSeconds; 139 140 // We will now go back in time to the state of the system at the time when 141 // the measurement was captured. We will reset the observer to that state, 142 // and apply correction based on the measurement. Then, we will go back 143 // through all observer states until the present and apply past inputs to 144 // get the present estimated state. 145 for (int i = indexOfClosestEntry; i < m_pastObserverSnapshots.size(); i++) { 146 var key = m_pastObserverSnapshots.get(i).getKey(); 147 var snapshot = m_pastObserverSnapshots.get(i).getValue(); 148 149 if (i == indexOfClosestEntry) { 150 observer.setP(snapshot.errorCovariances); 151 observer.setXhat(snapshot.xHat); 152 } 153 154 observer.predict(snapshot.inputs, key - lastTimestamp); 155 observer.correct(snapshot.inputs, snapshot.localMeasurements); 156 157 if (i == indexOfClosestEntry) { 158 // Note that the measurement is at a timestep close but probably not exactly equal to the 159 // timestep for which we called predict. 160 // This makes the assumption that the dt is small enough that the difference between the 161 // measurement time and the time that the inputs were captured at is very small. 162 globalMeasurementCorrect.accept(snapshot.inputs, y); 163 } 164 lastTimestamp = key; 165 166 m_pastObserverSnapshots.set( 167 i, 168 Map.entry( 169 key, new ObserverSnapshot(observer, snapshot.inputs, snapshot.localMeasurements))); 170 } 171 } 172 173 /** This class contains all the information about our observer at a given time. */ 174 public final class ObserverSnapshot { 175 /** The state estimate. */ 176 public final Matrix<S, N1> xHat; 177 178 /** The error covariance. */ 179 public final Matrix<S, S> errorCovariances; 180 181 /** The inputs. */ 182 public final Matrix<I, N1> inputs; 183 184 /** The local measurements. */ 185 public final Matrix<O, N1> localMeasurements; 186 187 private ObserverSnapshot( 188 KalmanTypeFilter<S, I, O> observer, Matrix<I, N1> u, Matrix<O, N1> localY) { 189 this.xHat = observer.getXhat(); 190 this.errorCovariances = observer.getP(); 191 192 inputs = u; 193 localMeasurements = localY; 194 } 195 } 196}