001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Nat;
009import edu.wpi.first.math.Num;
010import edu.wpi.first.math.numbers.N1;
011import java.util.ArrayList;
012import java.util.List;
013import java.util.Map;
014import java.util.function.BiConsumer;
015
016/**
017 * This class incorporates time-delayed measurements into a Kalman filter's state estimate.
018 *
019 * @param <S> The number of states.
020 * @param <I> The number of inputs.
021 * @param <O> The number of outputs.
022 */
023public class KalmanFilterLatencyCompensator<S extends Num, I extends Num, O extends Num> {
024  private static final int kMaxPastObserverStates = 300;
025
026  private final List<Map.Entry<Double, ObserverSnapshot>> m_pastObserverSnapshots;
027
028  /** Default constructor. */
029  KalmanFilterLatencyCompensator() {
030    m_pastObserverSnapshots = new ArrayList<>();
031  }
032
033  /** Clears the observer snapshot buffer. */
034  public void reset() {
035    m_pastObserverSnapshots.clear();
036  }
037
038  /**
039   * Add past observer states to the observer snapshots list.
040   *
041   * @param observer The observer.
042   * @param u The input at the timestamp.
043   * @param localY The local output at the timestamp
044   * @param timestampSeconds The timestamp of the state.
045   */
046  public void addObserverState(
047      KalmanTypeFilter<S, I, O> observer,
048      Matrix<I, N1> u,
049      Matrix<O, N1> localY,
050      double timestampSeconds) {
051    m_pastObserverSnapshots.add(
052        Map.entry(timestampSeconds, new ObserverSnapshot(observer, u, localY)));
053
054    if (m_pastObserverSnapshots.size() > kMaxPastObserverStates) {
055      m_pastObserverSnapshots.remove(0);
056    }
057  }
058
059  /**
060   * Add past global measurements (such as from vision)to the estimator.
061   *
062   * @param <R> The rows in the global measurement vector.
063   * @param rows The rows in the global measurement vector.
064   * @param observer The observer to apply the past global measurement.
065   * @param nominalDtSeconds The nominal timestep.
066   * @param y The measurement.
067   * @param globalMeasurementCorrect The function take calls correct() on the observer.
068   * @param timestampSeconds The timestamp of the measurement.
069   */
070  public <R extends Num> void applyPastGlobalMeasurement(
071      Nat<R> rows,
072      KalmanTypeFilter<S, I, O> observer,
073      double nominalDtSeconds,
074      Matrix<R, N1> y,
075      BiConsumer<Matrix<I, N1>, Matrix<R, N1>> globalMeasurementCorrect,
076      double timestampSeconds) {
077    if (m_pastObserverSnapshots.isEmpty()) {
078      // State map was empty, which means that we got a past measurement right at startup. The only
079      // thing we can really do is ignore the measurement.
080      return;
081    }
082
083    // Use a less verbose name for timestamp
084    double timestamp = timestampSeconds;
085
086    int maxIdx = m_pastObserverSnapshots.size() - 1;
087    int low = 0;
088    int high = maxIdx;
089
090    // Perform a binary search to find the index of first snapshot whose
091    // timestamp is greater than or equal to the global measurement timestamp
092    while (low != high) {
093      int mid = (low + high) / 2;
094      if (m_pastObserverSnapshots.get(mid).getKey() < timestamp) {
095        // This index and everything under it are less than the requested timestamp. Therefore, we
096        // can discard them.
097        low = mid + 1;
098      } else {
099        // t is at least as large as the element at this index. This means that anything after it
100        // cannot be what we are looking for.
101        high = mid;
102      }
103    }
104
105    int indexOfClosestEntry;
106
107    if (low == 0) {
108      // If the global measurement is older than any snapshot, throw out the
109      // measurement because there's no state estimate into which to incorporate
110      // the measurement
111      if (timestamp < m_pastObserverSnapshots.get(low).getKey()) {
112        return;
113      }
114
115      // If the first snapshot has same timestamp as the global measurement, use
116      // that snapshot
117      indexOfClosestEntry = 0;
118    } else if (low == maxIdx && m_pastObserverSnapshots.get(low).getKey() < timestamp) {
119      // If all snapshots are older than the global measurement, use the newest
120      // snapshot
121      indexOfClosestEntry = maxIdx;
122    } else {
123      // Index of snapshot taken after the global measurement
124      int nextIdx = low;
125
126      // Index of snapshot taken before the global measurement. Since we already
127      // handled the case where the index points to the first snapshot, this
128      // computation is guaranteed to be non-negative.
129      int prevIdx = nextIdx - 1;
130
131      // Find the snapshot closest in time to global measurement
132      double prevTimeDiff = Math.abs(timestamp - m_pastObserverSnapshots.get(prevIdx).getKey());
133      double nextTimeDiff = Math.abs(timestamp - m_pastObserverSnapshots.get(nextIdx).getKey());
134      indexOfClosestEntry = prevTimeDiff <= nextTimeDiff ? prevIdx : nextIdx;
135    }
136
137    double lastTimestamp =
138        m_pastObserverSnapshots.get(indexOfClosestEntry).getKey() - nominalDtSeconds;
139
140    // We will now go back in time to the state of the system at the time when
141    // the measurement was captured. We will reset the observer to that state,
142    // and apply correction based on the measurement. Then, we will go back
143    // through all observer states until the present and apply past inputs to
144    // get the present estimated state.
145    for (int i = indexOfClosestEntry; i < m_pastObserverSnapshots.size(); i++) {
146      var key = m_pastObserverSnapshots.get(i).getKey();
147      var snapshot = m_pastObserverSnapshots.get(i).getValue();
148
149      if (i == indexOfClosestEntry) {
150        observer.setP(snapshot.errorCovariances);
151        observer.setXhat(snapshot.xHat);
152      }
153
154      observer.predict(snapshot.inputs, key - lastTimestamp);
155      observer.correct(snapshot.inputs, snapshot.localMeasurements);
156
157      if (i == indexOfClosestEntry) {
158        // Note that the measurement is at a timestep close but probably not exactly equal to the
159        // timestep for which we called predict.
160        // This makes the assumption that the dt is small enough that the difference between the
161        // measurement time and the time that the inputs were captured at is very small.
162        globalMeasurementCorrect.accept(snapshot.inputs, y);
163      }
164      lastTimestamp = key;
165
166      m_pastObserverSnapshots.set(
167          i,
168          Map.entry(
169              key, new ObserverSnapshot(observer, snapshot.inputs, snapshot.localMeasurements)));
170    }
171  }
172
173  /** This class contains all the information about our observer at a given time. */
174  public class ObserverSnapshot {
175    /** The state estimate. */
176    public final Matrix<S, N1> xHat;
177
178    /** The error covariance. */
179    public final Matrix<S, S> errorCovariances;
180
181    /** The inputs. */
182    public final Matrix<I, N1> inputs;
183
184    /** The local measurements. */
185    public final Matrix<O, N1> localMeasurements;
186
187    private ObserverSnapshot(
188        KalmanTypeFilter<S, I, O> observer, Matrix<I, N1> u, Matrix<O, N1> localY) {
189      this.xHat = observer.getXhat();
190      this.errorCovariances = observer.getP();
191
192      inputs = u;
193      localMeasurements = localY;
194    }
195  }
196}