001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Nat;
009import edu.wpi.first.math.Num;
010import edu.wpi.first.math.numbers.N1;
011import java.util.ArrayList;
012import java.util.List;
013import java.util.Map;
014import java.util.function.BiConsumer;
015
016public class KalmanFilterLatencyCompensator<S extends Num, I extends Num, O extends Num> {
017  private static final int kMaxPastObserverStates = 300;
018
019  private final List<Map.Entry<Double, ObserverSnapshot>> m_pastObserverSnapshots;
020
021  KalmanFilterLatencyCompensator() {
022    m_pastObserverSnapshots = new ArrayList<>();
023  }
024
025  /** Clears the observer snapshot buffer. */
026  public void reset() {
027    m_pastObserverSnapshots.clear();
028  }
029
030  /**
031   * Add past observer states to the observer snapshots list.
032   *
033   * @param observer The observer.
034   * @param u The input at the timestamp.
035   * @param localY The local output at the timestamp
036   * @param timestampSeconds The timestamp of the state.
037   */
038  public void addObserverState(
039      KalmanTypeFilter<S, I, O> observer,
040      Matrix<I, N1> u,
041      Matrix<O, N1> localY,
042      double timestampSeconds) {
043    m_pastObserverSnapshots.add(
044        Map.entry(timestampSeconds, new ObserverSnapshot(observer, u, localY)));
045
046    if (m_pastObserverSnapshots.size() > kMaxPastObserverStates) {
047      m_pastObserverSnapshots.remove(0);
048    }
049  }
050
051  /**
052   * Add past global measurements (such as from vision)to the estimator.
053   *
054   * @param <R> The rows in the global measurement vector.
055   * @param rows The rows in the global measurement vector.
056   * @param observer The observer to apply the past global measurement.
057   * @param nominalDtSeconds The nominal timestep.
058   * @param y The measurement.
059   * @param globalMeasurementCorrect The function take calls correct() on the observer.
060   * @param timestampSeconds The timestamp of the measurement.
061   */
062  public <R extends Num> void applyPastGlobalMeasurement(
063      Nat<R> rows,
064      KalmanTypeFilter<S, I, O> observer,
065      double nominalDtSeconds,
066      Matrix<R, N1> y,
067      BiConsumer<Matrix<I, N1>, Matrix<R, N1>> globalMeasurementCorrect,
068      double timestampSeconds) {
069    if (m_pastObserverSnapshots.isEmpty()) {
070      // State map was empty, which means that we got a past measurement right at startup. The only
071      // thing we can really do is ignore the measurement.
072      return;
073    }
074
075    // Use a less verbose name for timestamp
076    double timestamp = timestampSeconds;
077
078    int maxIdx = m_pastObserverSnapshots.size() - 1;
079    int low = 0;
080    int high = maxIdx;
081
082    // Perform a binary search to find the index of first snapshot whose
083    // timestamp is greater than or equal to the global measurement timestamp
084    while (low != high) {
085      int mid = (low + high) / 2;
086      if (m_pastObserverSnapshots.get(mid).getKey() < timestamp) {
087        // This index and everything under it are less than the requested timestamp. Therefore, we
088        // can discard them.
089        low = mid + 1;
090      } else {
091        // t is at least as large as the element at this index. This means that anything after it
092        // cannot be what we are looking for.
093        high = mid;
094      }
095    }
096
097    int indexOfClosestEntry;
098
099    if (low == 0) {
100      // If the global measurement is older than any snapshot, throw out the
101      // measurement because there's no state estimate into which to incorporate
102      // the measurement
103      if (timestamp < m_pastObserverSnapshots.get(low).getKey()) {
104        return;
105      }
106
107      // If the first snapshot has same timestamp as the global measurement, use
108      // that snapshot
109      indexOfClosestEntry = 0;
110    } else if (low == maxIdx && m_pastObserverSnapshots.get(low).getKey() < timestamp) {
111      // If all snapshots are older than the global measurement, use the newest
112      // snapshot
113      indexOfClosestEntry = maxIdx;
114    } else {
115      // Index of snapshot taken after the global measurement
116      int nextIdx = low;
117
118      // Index of snapshot taken before the global measurement. Since we already
119      // handled the case where the index points to the first snapshot, this
120      // computation is guaranteed to be non-negative.
121      int prevIdx = nextIdx - 1;
122
123      // Find the snapshot closest in time to global measurement
124      double prevTimeDiff = Math.abs(timestamp - m_pastObserverSnapshots.get(prevIdx).getKey());
125      double nextTimeDiff = Math.abs(timestamp - m_pastObserverSnapshots.get(nextIdx).getKey());
126      indexOfClosestEntry = prevTimeDiff <= nextTimeDiff ? prevIdx : nextIdx;
127    }
128
129    double lastTimestamp =
130        m_pastObserverSnapshots.get(indexOfClosestEntry).getKey() - nominalDtSeconds;
131
132    // We will now go back in time to the state of the system at the time when
133    // the measurement was captured. We will reset the observer to that state,
134    // and apply correction based on the measurement. Then, we will go back
135    // through all observer states until the present and apply past inputs to
136    // get the present estimated state.
137    for (int i = indexOfClosestEntry; i < m_pastObserverSnapshots.size(); i++) {
138      var key = m_pastObserverSnapshots.get(i).getKey();
139      var snapshot = m_pastObserverSnapshots.get(i).getValue();
140
141      if (i == indexOfClosestEntry) {
142        observer.setP(snapshot.errorCovariances);
143        observer.setXhat(snapshot.xHat);
144      }
145
146      observer.predict(snapshot.inputs, key - lastTimestamp);
147      observer.correct(snapshot.inputs, snapshot.localMeasurements);
148
149      if (i == indexOfClosestEntry) {
150        // Note that the measurement is at a timestep close but probably not exactly equal to the
151        // timestep for which we called predict.
152        // This makes the assumption that the dt is small enough that the difference between the
153        // measurement time and the time that the inputs were captured at is very small.
154        globalMeasurementCorrect.accept(snapshot.inputs, y);
155      }
156      lastTimestamp = key;
157
158      m_pastObserverSnapshots.set(
159          i,
160          Map.entry(
161              key, new ObserverSnapshot(observer, snapshot.inputs, snapshot.localMeasurements)));
162    }
163  }
164
165  /** This class contains all the information about our observer at a given time. */
166  public class ObserverSnapshot {
167    public final Matrix<S, N1> xHat;
168    public final Matrix<S, S> errorCovariances;
169    public final Matrix<I, N1> inputs;
170    public final Matrix<O, N1> localMeasurements;
171
172    private ObserverSnapshot(
173        KalmanTypeFilter<S, I, O> observer, Matrix<I, N1> u, Matrix<O, N1> localY) {
174      this.xHat = observer.getXhat();
175      this.errorCovariances = observer.getP();
176
177      inputs = u;
178      localMeasurements = localY;
179    }
180  }
181}