001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package org.wpilib.math.system;
006
007import java.util.function.BiFunction;
008import java.util.function.DoubleBinaryOperator;
009import java.util.function.DoubleUnaryOperator;
010import java.util.function.UnaryOperator;
011import org.ejml.simple.SimpleMatrix;
012import org.wpilib.math.linalg.Matrix;
013import org.wpilib.math.numbers.N1;
014import org.wpilib.math.util.Num;
015
016/** Numerical integration utilities. */
017public final class NumericalIntegration {
018  private NumericalIntegration() {
019    // utility Class
020  }
021
022  /**
023   * Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
024   *
025   * @param f The function to integrate. It must take one argument x.
026   * @param x The initial value of x.
027   * @param dt The time over which to integrate in seconds.
028   * @return the integration of dx/dt = f(x) for dt.
029   */
030  public static double rk4(DoubleUnaryOperator f, double x, double dt) {
031    final var h = dt;
032    final var k1 = f.applyAsDouble(x);
033    final var k2 = f.applyAsDouble(x + h * k1 * 0.5);
034    final var k3 = f.applyAsDouble(x + h * k2 * 0.5);
035    final var k4 = f.applyAsDouble(x + h * k3);
036
037    return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
038  }
039
040  /**
041   * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
042   *
043   * @param f The function to integrate. It must take two arguments x and u.
044   * @param x The initial value of x.
045   * @param u The value u held constant over the integration period.
046   * @param dt The time over which to integrate in seconds.
047   * @return the integration of dx/dt = f(x, u) for dt.
048   */
049  public static double rk4(DoubleBinaryOperator f, double x, double u, double dt) {
050    final var h = dt;
051
052    final var k1 = f.applyAsDouble(x, u);
053    final var k2 = f.applyAsDouble(x + h * k1 * 0.5, u);
054    final var k3 = f.applyAsDouble(x + h * k2 * 0.5, u);
055    final var k4 = f.applyAsDouble(x + h * k3, u);
056
057    return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
058  }
059
060  /**
061   * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
062   *
063   * @param f The function to integrate. It must take two arguments x and u.
064   * @param x The initial value of x.
065   * @param u The value u held constant over the integration period.
066   * @param dt The time over which to integrate.
067   * @return the integration of dx/dt = f(x, u) for dt.
068   */
069  public static SimpleMatrix rk4(
070      BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> f,
071      SimpleMatrix x,
072      SimpleMatrix u,
073      double dt) {
074    var h = dt;
075
076    var k1 = f.apply(x, u);
077    var k2 = f.apply(x.plus(k1.scale(h * 0.5)), u);
078    var k3 = f.apply(x.plus(k2.scale(h * 0.5)), u);
079    var k4 = f.apply(x.plus(k3.scale(h)), u);
080
081    return x.plus(k1.plus(k2.scale(2.0)).plus(k3.scale(2.0)).plus(k4).scale(h / 6.0));
082  }
083
084  /**
085   * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
086   *
087   * @param <States> A Num representing the states of the system to integrate.
088   * @param <Inputs> A Num representing the inputs of the system to integrate.
089   * @param f The function to integrate. It must take two arguments x and u.
090   * @param x The initial value of x.
091   * @param u The value u held constant over the integration period.
092   * @param dt The time over which to integrate in seconds.
093   * @return the integration of dx/dt = f(x, u) for dt.
094   */
095  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rk4(
096      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
097      Matrix<States, N1> x,
098      Matrix<Inputs, N1> u,
099      double dt) {
100    final var h = dt;
101
102    Matrix<States, N1> k1 = f.apply(x, u);
103    Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)), u);
104    Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)), u);
105    Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)), u);
106
107    return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
108  }
109
110  /**
111   * Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
112   *
113   * @param <States> A Num representing the states of the system.
114   * @param f The function to integrate. It must take one argument x.
115   * @param x The initial value of x.
116   * @param dt The time over which to integrate in seconds.
117   * @return the integration of dx/dt = f(x) for dt.
118   */
119  public static <States extends Num> Matrix<States, N1> rk4(
120      UnaryOperator<Matrix<States, N1>> f, Matrix<States, N1> x, double dt) {
121    final var h = dt;
122
123    Matrix<States, N1> k1 = f.apply(x);
124    Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)));
125    Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)));
126    Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)));
127
128    return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
129  }
130
131  /**
132   * Performs 4th order Runge-Kutta integration of dx/dt = f(t, y) for dt.
133   *
134   * @param <Rows> Rows in y.
135   * @param <Cols> Columns in y.
136   * @param f The function to integrate. It must take two arguments t and y.
137   * @param t The initial value of t.
138   * @param y The initial value of y.
139   * @param dt The time over which to integrate in seconds.
140   * @return the integration of dx/dt = f(x) for dt.
141   */
142  public static <Rows extends Num, Cols extends Num> Matrix<Rows, Cols> rk4(
143      BiFunction<Double, Matrix<Rows, Cols>, Matrix<Rows, Cols>> f,
144      double t,
145      Matrix<Rows, Cols> y,
146      double dt) {
147    final var h = dt;
148
149    Matrix<Rows, Cols> k1 = f.apply(t, y);
150    Matrix<Rows, Cols> k2 = f.apply(t + dt * 0.5, y.plus(k1.times(h * 0.5)));
151    Matrix<Rows, Cols> k3 = f.apply(t + dt * 0.5, y.plus(k2.times(h * 0.5)));
152    Matrix<Rows, Cols> k4 = f.apply(t + dt, y.plus(k3.times(h)));
153
154    return y.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
155  }
156
157  /**
158   * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt. By default, the max
159   * error is 1e-6.
160   *
161   * @param <States> A Num representing the states of the system to integrate.
162   * @param <Inputs> A Num representing the inputs of the system to integrate.
163   * @param f The function to integrate. It must take two arguments x and u.
164   * @param x The initial value of x.
165   * @param u The value u held constant over the integration period.
166   * @param dt The time over which to integrate in seconds.
167   * @return the integration of dx/dt = f(x, u) for dt.
168   */
169  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp(
170      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
171      Matrix<States, N1> x,
172      Matrix<Inputs, N1> u,
173      double dt) {
174    return rkdp(f, x, u, dt, 1e-6);
175  }
176
177  /**
178   * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt.
179   *
180   * @param <States> A Num representing the states of the system to integrate.
181   * @param <Inputs> A Num representing the inputs of the system to integrate.
182   * @param f The function to integrate. It must take two arguments x and u.
183   * @param x The initial value of x.
184   * @param u The value u held constant over the integration period.
185   * @param dt The time over which to integrate in seconds.
186   * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6.
187   * @return the integration of dx/dt = f(x, u) for dt.
188   */
189  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp(
190      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
191      Matrix<States, N1> x,
192      Matrix<Inputs, N1> u,
193      double dt,
194      double maxError) {
195    // See https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method for the
196    // Butcher tableau the following arrays came from.
197
198    // final double[6][6]
199    final double[][] A = {
200      {1.0 / 5.0},
201      {3.0 / 40.0, 9.0 / 40.0},
202      {44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0},
203      {19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0},
204      {9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0},
205      {35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0}
206    };
207
208    // final double[7]
209    final double[] b1 = {
210      35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0
211    };
212
213    // final double[7]
214    final double[] b2 = {
215      5179.0 / 57600.0,
216      0.0,
217      7571.0 / 16695.0,
218      393.0 / 640.0,
219      -92097.0 / 339200.0,
220      187.0 / 2100.0,
221      1.0 / 40.0
222    };
223
224    Matrix<States, N1> newX;
225    double truncationError;
226
227    double dtElapsed = 0.0;
228    double h = dt;
229
230    // Loop until we've gotten to our desired dt
231    while (dtElapsed < dt) {
232      do {
233        // Only allow us to advance up to the dt remaining
234        h = Math.min(h, dt - dtElapsed);
235
236        var k1 = f.apply(x, u);
237        var k2 = f.apply(x.plus(k1.times(A[0][0]).times(h)), u);
238        var k3 = f.apply(x.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)), u);
239        var k4 =
240            f.apply(
241                x.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)),
242                u);
243        var k5 =
244            f.apply(
245                x.plus(
246                    k1.times(A[3][0])
247                        .plus(k2.times(A[3][1]))
248                        .plus(k3.times(A[3][2]))
249                        .plus(k4.times(A[3][3]))
250                        .times(h)),
251                u);
252        var k6 =
253            f.apply(
254                x.plus(
255                    k1.times(A[4][0])
256                        .plus(k2.times(A[4][1]))
257                        .plus(k3.times(A[4][2]))
258                        .plus(k4.times(A[4][3]))
259                        .plus(k5.times(A[4][4]))
260                        .times(h)),
261                u);
262
263        // Since the final row of A and the array b1 have the same coefficients
264        // and k7 has no effect on newX, we can reuse the calculation.
265        newX =
266            x.plus(
267                k1.times(A[5][0])
268                    .plus(k2.times(A[5][1]))
269                    .plus(k3.times(A[5][2]))
270                    .plus(k4.times(A[5][3]))
271                    .plus(k5.times(A[5][4]))
272                    .plus(k6.times(A[5][5]))
273                    .times(h));
274        var k7 = f.apply(newX, u);
275
276        truncationError =
277            (k1.times(b1[0] - b2[0])
278                    .plus(k2.times(b1[1] - b2[1]))
279                    .plus(k3.times(b1[2] - b2[2]))
280                    .plus(k4.times(b1[3] - b2[3]))
281                    .plus(k5.times(b1[4] - b2[4]))
282                    .plus(k6.times(b1[5] - b2[5]))
283                    .plus(k7.times(b1[6] - b2[6]))
284                    .times(h))
285                .normF();
286
287        if (truncationError == 0.0) {
288          h = dt - dtElapsed;
289        } else {
290          h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0);
291        }
292      } while (truncationError > maxError);
293
294      dtElapsed += h;
295      x = newX;
296    }
297
298    return x;
299  }
300
301  /**
302   * Performs adaptive Dormand-Prince integration of dx/dt = f(t, y) for dt.
303   *
304   * @param <Rows> Rows in y.
305   * @param <Cols> Columns in y.
306   * @param f The function to integrate. It must take two arguments t and y.
307   * @param t The initial value of t.
308   * @param y The initial value of y.
309   * @param dt The time over which to integrate in seconds.
310   * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6.
311   * @return the integration of dx/dt = f(x, u) for dt.
312   */
313  public static <Rows extends Num, Cols extends Num> Matrix<Rows, Cols> rkdp(
314      BiFunction<Double, Matrix<Rows, Cols>, Matrix<Rows, Cols>> f,
315      double t,
316      Matrix<Rows, Cols> y,
317      double dt,
318      double maxError) {
319    // See https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method for the
320    // Butcher tableau the following arrays came from.
321
322    // final double[6][6]
323    final double[][] A = {
324      {1.0 / 5.0},
325      {3.0 / 40.0, 9.0 / 40.0},
326      {44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0},
327      {19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0},
328      {9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0},
329      {35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0}
330    };
331
332    // final double[7]
333    final double[] b1 = {
334      35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0
335    };
336
337    // final double[7]
338    final double[] b2 = {
339      5179.0 / 57600.0,
340      0.0,
341      7571.0 / 16695.0,
342      393.0 / 640.0,
343      -92097.0 / 339200.0,
344      187.0 / 2100.0,
345      1.0 / 40.0
346    };
347
348    // final double[6]
349    final double[] c = {1.0 / 5.0, 3.0 / 10.0, 4.0 / 5.0, 8.0 / 9.0, 1.0, 1.0};
350
351    Matrix<Rows, Cols> newY;
352    double truncationError;
353
354    double dtElapsed = 0.0;
355    double h = dt;
356
357    // Loop until we've gotten to our desired dt
358    while (dtElapsed < dt) {
359      do {
360        // Only allow us to advance up to the dt remaining
361        h = Math.min(h, dt - dtElapsed);
362
363        var k1 = f.apply(t, y);
364        var k2 = f.apply(t + h * c[0], y.plus(k1.times(A[0][0]).times(h)));
365        var k3 = f.apply(t + h * c[1], y.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)));
366        var k4 =
367            f.apply(
368                t + h * c[2],
369                y.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)));
370        var k5 =
371            f.apply(
372                t + h * c[3],
373                y.plus(
374                    k1.times(A[3][0])
375                        .plus(k2.times(A[3][1]))
376                        .plus(k3.times(A[3][2]))
377                        .plus(k4.times(A[3][3]))
378                        .times(h)));
379        var k6 =
380            f.apply(
381                t + h * c[4],
382                y.plus(
383                    k1.times(A[4][0])
384                        .plus(k2.times(A[4][1]))
385                        .plus(k3.times(A[4][2]))
386                        .plus(k4.times(A[4][3]))
387                        .plus(k5.times(A[4][4]))
388                        .times(h)));
389
390        // Since the final row of A and the array b1 have the same coefficients
391        // and k7 has no effect on newY, we can reuse the calculation.
392        newY =
393            y.plus(
394                k1.times(A[5][0])
395                    .plus(k2.times(A[5][1]))
396                    .plus(k3.times(A[5][2]))
397                    .plus(k4.times(A[5][3]))
398                    .plus(k5.times(A[5][4]))
399                    .plus(k6.times(A[5][5]))
400                    .times(h));
401        var k7 = f.apply(t + h * c[5], newY);
402
403        truncationError =
404            (k1.times(b1[0] - b2[0])
405                    .plus(k2.times(b1[1] - b2[1]))
406                    .plus(k3.times(b1[2] - b2[2]))
407                    .plus(k4.times(b1[3] - b2[3]))
408                    .plus(k5.times(b1[4] - b2[4]))
409                    .plus(k6.times(b1[5] - b2[5]))
410                    .plus(k7.times(b1[6] - b2[6]))
411                    .times(h))
412                .normF();
413
414        if (truncationError == 0.0) {
415          h = dt - dtElapsed;
416        } else {
417          h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0);
418        }
419      } while (truncationError > maxError);
420
421      dtElapsed += h;
422      y = newY;
423    }
424
425    return y;
426  }
427}