001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.system;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Num;
009import edu.wpi.first.math.numbers.N1;
010import java.util.function.BiFunction;
011import java.util.function.DoubleFunction;
012import java.util.function.Function;
013
014public final class NumericalIntegration {
015  private NumericalIntegration() {
016    // utility Class
017  }
018
019  /**
020   * Performs Runge Kutta integration (4th order).
021   *
022   * @param f The function to integrate, which takes one argument x.
023   * @param x The initial value of x.
024   * @param dtSeconds The time over which to integrate.
025   * @return the integration of dx/dt = f(x) for dt.
026   */
027  @SuppressWarnings("overloads")
028  public static double rk4(DoubleFunction<Double> f, double x, double dtSeconds) {
029    final var h = dtSeconds;
030    final var k1 = f.apply(x);
031    final var k2 = f.apply(x + h * k1 * 0.5);
032    final var k3 = f.apply(x + h * k2 * 0.5);
033    final var k4 = f.apply(x + h * k3);
034
035    return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
036  }
037
038  /**
039   * Performs Runge Kutta integration (4th order).
040   *
041   * @param f The function to integrate. It must take two arguments x and u.
042   * @param x The initial value of x.
043   * @param u The value u held constant over the integration period.
044   * @param dtSeconds The time over which to integrate.
045   * @return The result of Runge Kutta integration (4th order).
046   */
047  @SuppressWarnings("overloads")
048  public static double rk4(
049      BiFunction<Double, Double, Double> f, double x, Double u, double dtSeconds) {
050    final var h = dtSeconds;
051
052    final var k1 = f.apply(x, u);
053    final var k2 = f.apply(x + h * k1 * 0.5, u);
054    final var k3 = f.apply(x + h * k2 * 0.5, u);
055    final var k4 = f.apply(x + h * k3, u);
056
057    return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
058  }
059
060  /**
061   * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
062   *
063   * @param <States> A Num representing the states of the system to integrate.
064   * @param <Inputs> A Num representing the inputs of the system to integrate.
065   * @param f The function to integrate. It must take two arguments x and u.
066   * @param x The initial value of x.
067   * @param u The value u held constant over the integration period.
068   * @param dtSeconds The time over which to integrate.
069   * @return the integration of dx/dt = f(x, u) for dt.
070   */
071  @SuppressWarnings("overloads")
072  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rk4(
073      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
074      Matrix<States, N1> x,
075      Matrix<Inputs, N1> u,
076      double dtSeconds) {
077    final var h = dtSeconds;
078
079    Matrix<States, N1> k1 = f.apply(x, u);
080    Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)), u);
081    Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)), u);
082    Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)), u);
083
084    return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
085  }
086
087  /**
088   * Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
089   *
090   * @param <States> A Num prepresenting the states of the system.
091   * @param f The function to integrate. It must take one argument x.
092   * @param x The initial value of x.
093   * @param dtSeconds The time over which to integrate.
094   * @return 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
095   */
096  @SuppressWarnings("overloads")
097  public static <States extends Num> Matrix<States, N1> rk4(
098      Function<Matrix<States, N1>, Matrix<States, N1>> f, Matrix<States, N1> x, double dtSeconds) {
099    final var h = dtSeconds;
100
101    Matrix<States, N1> k1 = f.apply(x);
102    Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)));
103    Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)));
104    Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)));
105
106    return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
107  }
108
109  /**
110   * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt. By default, the max
111   * error is 1e-6.
112   *
113   * @param <States> A Num representing the states of the system to integrate.
114   * @param <Inputs> A Num representing the inputs of the system to integrate.
115   * @param f The function to integrate. It must take two arguments x and u.
116   * @param x The initial value of x.
117   * @param u The value u held constant over the integration period.
118   * @param dtSeconds The time over which to integrate.
119   * @return the integration of dx/dt = f(x, u) for dt.
120   */
121  @SuppressWarnings("overloads")
122  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp(
123      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
124      Matrix<States, N1> x,
125      Matrix<Inputs, N1> u,
126      double dtSeconds) {
127    return rkdp(f, x, u, dtSeconds, 1e-6);
128  }
129
130  /**
131   * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt.
132   *
133   * @param <States> A Num representing the states of the system to integrate.
134   * @param <Inputs> A Num representing the inputs of the system to integrate.
135   * @param f The function to integrate. It must take two arguments x and u.
136   * @param x The initial value of x.
137   * @param u The value u held constant over the integration period.
138   * @param dtSeconds The time over which to integrate.
139   * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6.
140   * @return the integration of dx/dt = f(x, u) for dt.
141   */
142  @SuppressWarnings("overloads")
143  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp(
144      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
145      Matrix<States, N1> x,
146      Matrix<Inputs, N1> u,
147      double dtSeconds,
148      double maxError) {
149    // See https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method for the
150    // Butcher tableau the following arrays came from.
151
152    // final double[6][6]
153    final double[][] A = {
154      {1.0 / 5.0},
155      {3.0 / 40.0, 9.0 / 40.0},
156      {44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0},
157      {19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0},
158      {9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0},
159      {35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0}
160    };
161
162    // final double[7]
163    final double[] b1 = {
164      35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0
165    };
166
167    // final double[7]
168    final double[] b2 = {
169      5179.0 / 57600.0,
170      0.0,
171      7571.0 / 16695.0,
172      393.0 / 640.0,
173      -92097.0 / 339200.0,
174      187.0 / 2100.0,
175      1.0 / 40.0
176    };
177
178    Matrix<States, N1> newX;
179    double truncationError;
180
181    double dtElapsed = 0.0;
182    double h = dtSeconds;
183
184    // Loop until we've gotten to our desired dt
185    while (dtElapsed < dtSeconds) {
186      do {
187        // Only allow us to advance up to the dt remaining
188        h = Math.min(h, dtSeconds - dtElapsed);
189
190        var k1 = f.apply(x, u);
191        var k2 = f.apply(x.plus(k1.times(A[0][0]).times(h)), u);
192        var k3 = f.apply(x.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)), u);
193        var k4 =
194            f.apply(
195                x.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)),
196                u);
197        var k5 =
198            f.apply(
199                x.plus(
200                    k1.times(A[3][0])
201                        .plus(k2.times(A[3][1]))
202                        .plus(k3.times(A[3][2]))
203                        .plus(k4.times(A[3][3]))
204                        .times(h)),
205                u);
206        var k6 =
207            f.apply(
208                x.plus(
209                    k1.times(A[4][0])
210                        .plus(k2.times(A[4][1]))
211                        .plus(k3.times(A[4][2]))
212                        .plus(k4.times(A[4][3]))
213                        .plus(k5.times(A[4][4]))
214                        .times(h)),
215                u);
216
217        // Since the final row of A and the array b1 have the same coefficients
218        // and k7 has no effect on newX, we can reuse the calculation.
219        newX =
220            x.plus(
221                k1.times(A[5][0])
222                    .plus(k2.times(A[5][1]))
223                    .plus(k3.times(A[5][2]))
224                    .plus(k4.times(A[5][3]))
225                    .plus(k5.times(A[5][4]))
226                    .plus(k6.times(A[5][5]))
227                    .times(h));
228        var k7 = f.apply(newX, u);
229
230        truncationError =
231            (k1.times(b1[0] - b2[0])
232                    .plus(k2.times(b1[1] - b2[1]))
233                    .plus(k3.times(b1[2] - b2[2]))
234                    .plus(k4.times(b1[3] - b2[3]))
235                    .plus(k5.times(b1[4] - b2[4]))
236                    .plus(k6.times(b1[5] - b2[5]))
237                    .plus(k7.times(b1[6] - b2[6]))
238                    .times(h))
239                .normF();
240
241        if (truncationError == 0.0) {
242          h = dtSeconds - dtElapsed;
243        } else {
244          h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0);
245        }
246      } while (truncationError > maxError);
247
248      dtElapsed += h;
249      x = newX;
250    }
251
252    return x;
253  }
254}