001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.system;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Num;
009import edu.wpi.first.math.numbers.N1;
010import java.util.function.BiFunction;
011import java.util.function.DoubleFunction;
012import java.util.function.Function;
013
014/** Numerical integration utilities. */
015public final class NumericalIntegration {
016  private NumericalIntegration() {
017    // utility Class
018  }
019
020  /**
021   * Performs Runge Kutta integration (4th order).
022   *
023   * @param f The function to integrate, which takes one argument x.
024   * @param x The initial value of x.
025   * @param dtSeconds The time over which to integrate.
026   * @return the integration of dx/dt = f(x) for dt.
027   */
028  @SuppressWarnings("overloads")
029  public static double rk4(DoubleFunction<Double> f, double x, double dtSeconds) {
030    final var h = dtSeconds;
031    final var k1 = f.apply(x);
032    final var k2 = f.apply(x + h * k1 * 0.5);
033    final var k3 = f.apply(x + h * k2 * 0.5);
034    final var k4 = f.apply(x + h * k3);
035
036    return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
037  }
038
039  /**
040   * Performs Runge Kutta integration (4th order).
041   *
042   * @param f The function to integrate. It must take two arguments x and u.
043   * @param x The initial value of x.
044   * @param u The value u held constant over the integration period.
045   * @param dtSeconds The time over which to integrate.
046   * @return The result of Runge Kutta integration (4th order).
047   */
048  @SuppressWarnings("overloads")
049  public static double rk4(
050      BiFunction<Double, Double, Double> f, double x, Double u, double dtSeconds) {
051    final var h = dtSeconds;
052
053    final var k1 = f.apply(x, u);
054    final var k2 = f.apply(x + h * k1 * 0.5, u);
055    final var k3 = f.apply(x + h * k2 * 0.5, u);
056    final var k4 = f.apply(x + h * k3, u);
057
058    return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
059  }
060
061  /**
062   * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
063   *
064   * @param <States> A Num representing the states of the system to integrate.
065   * @param <Inputs> A Num representing the inputs of the system to integrate.
066   * @param f The function to integrate. It must take two arguments x and u.
067   * @param x The initial value of x.
068   * @param u The value u held constant over the integration period.
069   * @param dtSeconds The time over which to integrate.
070   * @return the integration of dx/dt = f(x, u) for dt.
071   */
072  @SuppressWarnings("overloads")
073  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rk4(
074      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
075      Matrix<States, N1> x,
076      Matrix<Inputs, N1> u,
077      double dtSeconds) {
078    final var h = dtSeconds;
079
080    Matrix<States, N1> k1 = f.apply(x, u);
081    Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)), u);
082    Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)), u);
083    Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)), u);
084
085    return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
086  }
087
088  /**
089   * Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
090   *
091   * @param <States> A Num prepresenting the states of the system.
092   * @param f The function to integrate. It must take one argument x.
093   * @param x The initial value of x.
094   * @param dtSeconds The time over which to integrate.
095   * @return 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
096   */
097  @SuppressWarnings("overloads")
098  public static <States extends Num> Matrix<States, N1> rk4(
099      Function<Matrix<States, N1>, Matrix<States, N1>> f, Matrix<States, N1> x, double dtSeconds) {
100    final var h = dtSeconds;
101
102    Matrix<States, N1> k1 = f.apply(x);
103    Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)));
104    Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)));
105    Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)));
106
107    return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
108  }
109
110  /**
111   * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt. By default, the max
112   * error is 1e-6.
113   *
114   * @param <States> A Num representing the states of the system to integrate.
115   * @param <Inputs> A Num representing the inputs of the system to integrate.
116   * @param f The function to integrate. It must take two arguments x and u.
117   * @param x The initial value of x.
118   * @param u The value u held constant over the integration period.
119   * @param dtSeconds The time over which to integrate.
120   * @return the integration of dx/dt = f(x, u) for dt.
121   */
122  @SuppressWarnings("overloads")
123  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp(
124      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
125      Matrix<States, N1> x,
126      Matrix<Inputs, N1> u,
127      double dtSeconds) {
128    return rkdp(f, x, u, dtSeconds, 1e-6);
129  }
130
131  /**
132   * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt.
133   *
134   * @param <States> A Num representing the states of the system to integrate.
135   * @param <Inputs> A Num representing the inputs of the system to integrate.
136   * @param f The function to integrate. It must take two arguments x and u.
137   * @param x The initial value of x.
138   * @param u The value u held constant over the integration period.
139   * @param dtSeconds The time over which to integrate.
140   * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6.
141   * @return the integration of dx/dt = f(x, u) for dt.
142   */
143  @SuppressWarnings("overloads")
144  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp(
145      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
146      Matrix<States, N1> x,
147      Matrix<Inputs, N1> u,
148      double dtSeconds,
149      double maxError) {
150    // See https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method for the
151    // Butcher tableau the following arrays came from.
152
153    // final double[6][6]
154    final double[][] A = {
155      {1.0 / 5.0},
156      {3.0 / 40.0, 9.0 / 40.0},
157      {44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0},
158      {19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0},
159      {9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0},
160      {35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0}
161    };
162
163    // final double[7]
164    final double[] b1 = {
165      35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0
166    };
167
168    // final double[7]
169    final double[] b2 = {
170      5179.0 / 57600.0,
171      0.0,
172      7571.0 / 16695.0,
173      393.0 / 640.0,
174      -92097.0 / 339200.0,
175      187.0 / 2100.0,
176      1.0 / 40.0
177    };
178
179    Matrix<States, N1> newX;
180    double truncationError;
181
182    double dtElapsed = 0.0;
183    double h = dtSeconds;
184
185    // Loop until we've gotten to our desired dt
186    while (dtElapsed < dtSeconds) {
187      do {
188        // Only allow us to advance up to the dt remaining
189        h = Math.min(h, dtSeconds - dtElapsed);
190
191        var k1 = f.apply(x, u);
192        var k2 = f.apply(x.plus(k1.times(A[0][0]).times(h)), u);
193        var k3 = f.apply(x.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)), u);
194        var k4 =
195            f.apply(
196                x.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)),
197                u);
198        var k5 =
199            f.apply(
200                x.plus(
201                    k1.times(A[3][0])
202                        .plus(k2.times(A[3][1]))
203                        .plus(k3.times(A[3][2]))
204                        .plus(k4.times(A[3][3]))
205                        .times(h)),
206                u);
207        var k6 =
208            f.apply(
209                x.plus(
210                    k1.times(A[4][0])
211                        .plus(k2.times(A[4][1]))
212                        .plus(k3.times(A[4][2]))
213                        .plus(k4.times(A[4][3]))
214                        .plus(k5.times(A[4][4]))
215                        .times(h)),
216                u);
217
218        // Since the final row of A and the array b1 have the same coefficients
219        // and k7 has no effect on newX, we can reuse the calculation.
220        newX =
221            x.plus(
222                k1.times(A[5][0])
223                    .plus(k2.times(A[5][1]))
224                    .plus(k3.times(A[5][2]))
225                    .plus(k4.times(A[5][3]))
226                    .plus(k5.times(A[5][4]))
227                    .plus(k6.times(A[5][5]))
228                    .times(h));
229        var k7 = f.apply(newX, u);
230
231        truncationError =
232            (k1.times(b1[0] - b2[0])
233                    .plus(k2.times(b1[1] - b2[1]))
234                    .plus(k3.times(b1[2] - b2[2]))
235                    .plus(k4.times(b1[3] - b2[3]))
236                    .plus(k5.times(b1[4] - b2[4]))
237                    .plus(k6.times(b1[5] - b2[5]))
238                    .plus(k7.times(b1[6] - b2[6]))
239                    .times(h))
240                .normF();
241
242        if (truncationError == 0.0) {
243          h = dtSeconds - dtElapsed;
244        } else {
245          h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0);
246        }
247      } while (truncationError > maxError);
248
249      dtElapsed += h;
250      x = newX;
251    }
252
253    return x;
254  }
255}