001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.system;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Num;
009import edu.wpi.first.math.numbers.N1;
010import java.util.function.BiFunction;
011import java.util.function.DoubleBinaryOperator;
012import java.util.function.DoubleUnaryOperator;
013import java.util.function.UnaryOperator;
014
015/** Numerical integration utilities. */
016public final class NumericalIntegration {
017  private NumericalIntegration() {
018    // utility Class
019  }
020
021  /**
022   * Performs Runge Kutta integration (4th order).
023   *
024   * @param f The function to integrate, which takes one argument x.
025   * @param x The initial value of x.
026   * @param dtSeconds The time over which to integrate.
027   * @return the integration of dx/dt = f(x) for dt.
028   */
029  public static double rk4(DoubleUnaryOperator f, double x, double dtSeconds) {
030    final var h = dtSeconds;
031    final var k1 = f.applyAsDouble(x);
032    final var k2 = f.applyAsDouble(x + h * k1 * 0.5);
033    final var k3 = f.applyAsDouble(x + h * k2 * 0.5);
034    final var k4 = f.applyAsDouble(x + h * k3);
035
036    return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
037  }
038
039  /**
040   * Performs Runge Kutta integration (4th order).
041   *
042   * @param f The function to integrate. It must take two arguments x and u.
043   * @param x The initial value of x.
044   * @param u The value u held constant over the integration period.
045   * @param dtSeconds The time over which to integrate.
046   * @return The result of Runge Kutta integration (4th order).
047   */
048  public static double rk4(DoubleBinaryOperator f, double x, double u, double dtSeconds) {
049    final var h = dtSeconds;
050
051    final var k1 = f.applyAsDouble(x, u);
052    final var k2 = f.applyAsDouble(x + h * k1 * 0.5, u);
053    final var k3 = f.applyAsDouble(x + h * k2 * 0.5, u);
054    final var k4 = f.applyAsDouble(x + h * k3, u);
055
056    return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
057  }
058
059  /**
060   * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
061   *
062   * @param <States> A Num representing the states of the system to integrate.
063   * @param <Inputs> A Num representing the inputs of the system to integrate.
064   * @param f The function to integrate. It must take two arguments x and u.
065   * @param x The initial value of x.
066   * @param u The value u held constant over the integration period.
067   * @param dtSeconds The time over which to integrate.
068   * @return the integration of dx/dt = f(x, u) for dt.
069   */
070  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rk4(
071      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
072      Matrix<States, N1> x,
073      Matrix<Inputs, N1> u,
074      double dtSeconds) {
075    final var h = dtSeconds;
076
077    Matrix<States, N1> k1 = f.apply(x, u);
078    Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)), u);
079    Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)), u);
080    Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)), u);
081
082    return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
083  }
084
085  /**
086   * Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
087   *
088   * @param <States> A Num representing the states of the system.
089   * @param f The function to integrate. It must take one argument x.
090   * @param x The initial value of x.
091   * @param dtSeconds The time over which to integrate.
092   * @return 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
093   */
094  public static <States extends Num> Matrix<States, N1> rk4(
095      UnaryOperator<Matrix<States, N1>> f, Matrix<States, N1> x, double dtSeconds) {
096    final var h = dtSeconds;
097
098    Matrix<States, N1> k1 = f.apply(x);
099    Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)));
100    Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)));
101    Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)));
102
103    return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
104  }
105
106  /**
107   * Performs 4th order Runge-Kutta integration of dx/dt = f(t, y) for dt.
108   *
109   * @param <Rows> Rows in y.
110   * @param <Cols> Columns in y.
111   * @param f The function to integrate. It must take two arguments t and y.
112   * @param t The initial value of t.
113   * @param y The initial value of y.
114   * @param dtSeconds The time over which to integrate.
115   * @return the integration of dx/dt = f(x) for dt.
116   */
117  public static <Rows extends Num, Cols extends Num> Matrix<Rows, Cols> rk4(
118      BiFunction<Double, Matrix<Rows, Cols>, Matrix<Rows, Cols>> f,
119      double t,
120      Matrix<Rows, Cols> y,
121      double dtSeconds) {
122    final var h = dtSeconds;
123
124    Matrix<Rows, Cols> k1 = f.apply(t, y);
125    Matrix<Rows, Cols> k2 = f.apply(t + dtSeconds * 0.5, y.plus(k1.times(h * 0.5)));
126    Matrix<Rows, Cols> k3 = f.apply(t + dtSeconds * 0.5, y.plus(k2.times(h * 0.5)));
127    Matrix<Rows, Cols> k4 = f.apply(t + dtSeconds, y.plus(k3.times(h)));
128
129    return y.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
130  }
131
132  /**
133   * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt. By default, the max
134   * error is 1e-6.
135   *
136   * @param <States> A Num representing the states of the system to integrate.
137   * @param <Inputs> A Num representing the inputs of the system to integrate.
138   * @param f The function to integrate. It must take two arguments x and u.
139   * @param x The initial value of x.
140   * @param u The value u held constant over the integration period.
141   * @param dtSeconds The time over which to integrate.
142   * @return the integration of dx/dt = f(x, u) for dt.
143   */
144  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp(
145      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
146      Matrix<States, N1> x,
147      Matrix<Inputs, N1> u,
148      double dtSeconds) {
149    return rkdp(f, x, u, dtSeconds, 1e-6);
150  }
151
152  /**
153   * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt.
154   *
155   * @param <States> A Num representing the states of the system to integrate.
156   * @param <Inputs> A Num representing the inputs of the system to integrate.
157   * @param f The function to integrate. It must take two arguments x and u.
158   * @param x The initial value of x.
159   * @param u The value u held constant over the integration period.
160   * @param dtSeconds The time over which to integrate.
161   * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6.
162   * @return the integration of dx/dt = f(x, u) for dt.
163   */
164  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp(
165      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
166      Matrix<States, N1> x,
167      Matrix<Inputs, N1> u,
168      double dtSeconds,
169      double maxError) {
170    // See https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method for the
171    // Butcher tableau the following arrays came from.
172
173    // final double[6][6]
174    final double[][] A = {
175      {1.0 / 5.0},
176      {3.0 / 40.0, 9.0 / 40.0},
177      {44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0},
178      {19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0},
179      {9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0},
180      {35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0}
181    };
182
183    // final double[7]
184    final double[] b1 = {
185      35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0
186    };
187
188    // final double[7]
189    final double[] b2 = {
190      5179.0 / 57600.0,
191      0.0,
192      7571.0 / 16695.0,
193      393.0 / 640.0,
194      -92097.0 / 339200.0,
195      187.0 / 2100.0,
196      1.0 / 40.0
197    };
198
199    Matrix<States, N1> newX;
200    double truncationError;
201
202    double dtElapsed = 0.0;
203    double h = dtSeconds;
204
205    // Loop until we've gotten to our desired dt
206    while (dtElapsed < dtSeconds) {
207      do {
208        // Only allow us to advance up to the dt remaining
209        h = Math.min(h, dtSeconds - dtElapsed);
210
211        var k1 = f.apply(x, u);
212        var k2 = f.apply(x.plus(k1.times(A[0][0]).times(h)), u);
213        var k3 = f.apply(x.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)), u);
214        var k4 =
215            f.apply(
216                x.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)),
217                u);
218        var k5 =
219            f.apply(
220                x.plus(
221                    k1.times(A[3][0])
222                        .plus(k2.times(A[3][1]))
223                        .plus(k3.times(A[3][2]))
224                        .plus(k4.times(A[3][3]))
225                        .times(h)),
226                u);
227        var k6 =
228            f.apply(
229                x.plus(
230                    k1.times(A[4][0])
231                        .plus(k2.times(A[4][1]))
232                        .plus(k3.times(A[4][2]))
233                        .plus(k4.times(A[4][3]))
234                        .plus(k5.times(A[4][4]))
235                        .times(h)),
236                u);
237
238        // Since the final row of A and the array b1 have the same coefficients
239        // and k7 has no effect on newX, we can reuse the calculation.
240        newX =
241            x.plus(
242                k1.times(A[5][0])
243                    .plus(k2.times(A[5][1]))
244                    .plus(k3.times(A[5][2]))
245                    .plus(k4.times(A[5][3]))
246                    .plus(k5.times(A[5][4]))
247                    .plus(k6.times(A[5][5]))
248                    .times(h));
249        var k7 = f.apply(newX, u);
250
251        truncationError =
252            (k1.times(b1[0] - b2[0])
253                    .plus(k2.times(b1[1] - b2[1]))
254                    .plus(k3.times(b1[2] - b2[2]))
255                    .plus(k4.times(b1[3] - b2[3]))
256                    .plus(k5.times(b1[4] - b2[4]))
257                    .plus(k6.times(b1[5] - b2[5]))
258                    .plus(k7.times(b1[6] - b2[6]))
259                    .times(h))
260                .normF();
261
262        if (truncationError == 0.0) {
263          h = dtSeconds - dtElapsed;
264        } else {
265          h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0);
266        }
267      } while (truncationError > maxError);
268
269      dtElapsed += h;
270      x = newX;
271    }
272
273    return x;
274  }
275
276  /**
277   * Performs adaptive Dormand-Prince integration of dx/dt = f(t, y) for dt.
278   *
279   * @param <Rows> Rows in y.
280   * @param <Cols> Columns in y.
281   * @param f The function to integrate. It must take two arguments t and y.
282   * @param t The initial value of t.
283   * @param y The initial value of y.
284   * @param dtSeconds The time over which to integrate.
285   * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6.
286   * @return the integration of dx/dt = f(x, u) for dt.
287   */
288  public static <Rows extends Num, Cols extends Num> Matrix<Rows, Cols> rkdp(
289      BiFunction<Double, Matrix<Rows, Cols>, Matrix<Rows, Cols>> f,
290      double t,
291      Matrix<Rows, Cols> y,
292      double dtSeconds,
293      double maxError) {
294    // See https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method for the
295    // Butcher tableau the following arrays came from.
296
297    // final double[6][6]
298    final double[][] A = {
299      {1.0 / 5.0},
300      {3.0 / 40.0, 9.0 / 40.0},
301      {44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0},
302      {19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0},
303      {9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0},
304      {35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0}
305    };
306
307    // final double[7]
308    final double[] b1 = {
309      35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0
310    };
311
312    // final double[7]
313    final double[] b2 = {
314      5179.0 / 57600.0,
315      0.0,
316      7571.0 / 16695.0,
317      393.0 / 640.0,
318      -92097.0 / 339200.0,
319      187.0 / 2100.0,
320      1.0 / 40.0
321    };
322
323    // final double[6]
324    final double[] c = {1.0 / 5.0, 3.0 / 10.0, 4.0 / 5.0, 8.0 / 9.0, 1.0, 1.0};
325
326    Matrix<Rows, Cols> newY;
327    double truncationError;
328
329    double dtElapsed = 0.0;
330    double h = dtSeconds;
331
332    // Loop until we've gotten to our desired dt
333    while (dtElapsed < dtSeconds) {
334      do {
335        // Only allow us to advance up to the dt remaining
336        h = Math.min(h, dtSeconds - dtElapsed);
337
338        var k1 = f.apply(t, y);
339        var k2 = f.apply(t + h * c[0], y.plus(k1.times(A[0][0]).times(h)));
340        var k3 = f.apply(t + h * c[1], y.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)));
341        var k4 =
342            f.apply(
343                t + h * c[2],
344                y.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)));
345        var k5 =
346            f.apply(
347                t + h * c[3],
348                y.plus(
349                    k1.times(A[3][0])
350                        .plus(k2.times(A[3][1]))
351                        .plus(k3.times(A[3][2]))
352                        .plus(k4.times(A[3][3]))
353                        .times(h)));
354        var k6 =
355            f.apply(
356                t + h * c[4],
357                y.plus(
358                    k1.times(A[4][0])
359                        .plus(k2.times(A[4][1]))
360                        .plus(k3.times(A[4][2]))
361                        .plus(k4.times(A[4][3]))
362                        .plus(k5.times(A[4][4]))
363                        .times(h)));
364
365        // Since the final row of A and the array b1 have the same coefficients
366        // and k7 has no effect on newY, we can reuse the calculation.
367        newY =
368            y.plus(
369                k1.times(A[5][0])
370                    .plus(k2.times(A[5][1]))
371                    .plus(k3.times(A[5][2]))
372                    .plus(k4.times(A[5][3]))
373                    .plus(k5.times(A[5][4]))
374                    .plus(k6.times(A[5][5]))
375                    .times(h));
376        var k7 = f.apply(t + h * c[5], newY);
377
378        truncationError =
379            (k1.times(b1[0] - b2[0])
380                    .plus(k2.times(b1[1] - b2[1]))
381                    .plus(k3.times(b1[2] - b2[2]))
382                    .plus(k4.times(b1[3] - b2[3]))
383                    .plus(k5.times(b1[4] - b2[4]))
384                    .plus(k6.times(b1[5] - b2[5]))
385                    .plus(k7.times(b1[6] - b2[6]))
386                    .times(h))
387                .normF();
388
389        if (truncationError == 0.0) {
390          h = dtSeconds - dtElapsed;
391        } else {
392          h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0);
393        }
394      } while (truncationError > maxError);
395
396      dtElapsed += h;
397      y = newY;
398    }
399
400    return y;
401  }
402}