001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.system; 006 007import edu.wpi.first.math.Matrix; 008import edu.wpi.first.math.Num; 009import edu.wpi.first.math.numbers.N1; 010import java.util.function.BiFunction; 011import java.util.function.DoubleFunction; 012import java.util.function.Function; 013 014/** Numerical integration utilities. */ 015public final class NumericalIntegration { 016 private NumericalIntegration() { 017 // utility Class 018 } 019 020 /** 021 * Performs Runge Kutta integration (4th order). 022 * 023 * @param f The function to integrate, which takes one argument x. 024 * @param x The initial value of x. 025 * @param dtSeconds The time over which to integrate. 026 * @return the integration of dx/dt = f(x) for dt. 027 */ 028 @SuppressWarnings("overloads") 029 public static double rk4(DoubleFunction<Double> f, double x, double dtSeconds) { 030 final var h = dtSeconds; 031 final var k1 = f.apply(x); 032 final var k2 = f.apply(x + h * k1 * 0.5); 033 final var k3 = f.apply(x + h * k2 * 0.5); 034 final var k4 = f.apply(x + h * k3); 035 036 return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4); 037 } 038 039 /** 040 * Performs Runge Kutta integration (4th order). 041 * 042 * @param f The function to integrate. It must take two arguments x and u. 043 * @param x The initial value of x. 044 * @param u The value u held constant over the integration period. 045 * @param dtSeconds The time over which to integrate. 046 * @return The result of Runge Kutta integration (4th order). 047 */ 048 @SuppressWarnings("overloads") 049 public static double rk4( 050 BiFunction<Double, Double, Double> f, double x, Double u, double dtSeconds) { 051 final var h = dtSeconds; 052 053 final var k1 = f.apply(x, u); 054 final var k2 = f.apply(x + h * k1 * 0.5, u); 055 final var k3 = f.apply(x + h * k2 * 0.5, u); 056 final var k4 = f.apply(x + h * k3, u); 057 058 return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4); 059 } 060 061 /** 062 * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt. 063 * 064 * @param <States> A Num representing the states of the system to integrate. 065 * @param <Inputs> A Num representing the inputs of the system to integrate. 066 * @param f The function to integrate. It must take two arguments x and u. 067 * @param x The initial value of x. 068 * @param u The value u held constant over the integration period. 069 * @param dtSeconds The time over which to integrate. 070 * @return the integration of dx/dt = f(x, u) for dt. 071 */ 072 @SuppressWarnings("overloads") 073 public static <States extends Num, Inputs extends Num> Matrix<States, N1> rk4( 074 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 075 Matrix<States, N1> x, 076 Matrix<Inputs, N1> u, 077 double dtSeconds) { 078 final var h = dtSeconds; 079 080 Matrix<States, N1> k1 = f.apply(x, u); 081 Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)), u); 082 Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)), u); 083 Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)), u); 084 085 return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0)); 086 } 087 088 /** 089 * Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt. 090 * 091 * @param <States> A Num prepresenting the states of the system. 092 * @param f The function to integrate. It must take one argument x. 093 * @param x The initial value of x. 094 * @param dtSeconds The time over which to integrate. 095 * @return 4th order Runge-Kutta integration of dx/dt = f(x) for dt. 096 */ 097 @SuppressWarnings("overloads") 098 public static <States extends Num> Matrix<States, N1> rk4( 099 Function<Matrix<States, N1>, Matrix<States, N1>> f, Matrix<States, N1> x, double dtSeconds) { 100 final var h = dtSeconds; 101 102 Matrix<States, N1> k1 = f.apply(x); 103 Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5))); 104 Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5))); 105 Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h))); 106 107 return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0)); 108 } 109 110 /** 111 * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt. By default, the max 112 * error is 1e-6. 113 * 114 * @param <States> A Num representing the states of the system to integrate. 115 * @param <Inputs> A Num representing the inputs of the system to integrate. 116 * @param f The function to integrate. It must take two arguments x and u. 117 * @param x The initial value of x. 118 * @param u The value u held constant over the integration period. 119 * @param dtSeconds The time over which to integrate. 120 * @return the integration of dx/dt = f(x, u) for dt. 121 */ 122 @SuppressWarnings("overloads") 123 public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp( 124 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 125 Matrix<States, N1> x, 126 Matrix<Inputs, N1> u, 127 double dtSeconds) { 128 return rkdp(f, x, u, dtSeconds, 1e-6); 129 } 130 131 /** 132 * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt. 133 * 134 * @param <States> A Num representing the states of the system to integrate. 135 * @param <Inputs> A Num representing the inputs of the system to integrate. 136 * @param f The function to integrate. It must take two arguments x and u. 137 * @param x The initial value of x. 138 * @param u The value u held constant over the integration period. 139 * @param dtSeconds The time over which to integrate. 140 * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6. 141 * @return the integration of dx/dt = f(x, u) for dt. 142 */ 143 @SuppressWarnings("overloads") 144 public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp( 145 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 146 Matrix<States, N1> x, 147 Matrix<Inputs, N1> u, 148 double dtSeconds, 149 double maxError) { 150 // See https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method for the 151 // Butcher tableau the following arrays came from. 152 153 // final double[6][6] 154 final double[][] A = { 155 {1.0 / 5.0}, 156 {3.0 / 40.0, 9.0 / 40.0}, 157 {44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0}, 158 {19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0}, 159 {9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0}, 160 {35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0} 161 }; 162 163 // final double[7] 164 final double[] b1 = { 165 35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0 166 }; 167 168 // final double[7] 169 final double[] b2 = { 170 5179.0 / 57600.0, 171 0.0, 172 7571.0 / 16695.0, 173 393.0 / 640.0, 174 -92097.0 / 339200.0, 175 187.0 / 2100.0, 176 1.0 / 40.0 177 }; 178 179 Matrix<States, N1> newX; 180 double truncationError; 181 182 double dtElapsed = 0.0; 183 double h = dtSeconds; 184 185 // Loop until we've gotten to our desired dt 186 while (dtElapsed < dtSeconds) { 187 do { 188 // Only allow us to advance up to the dt remaining 189 h = Math.min(h, dtSeconds - dtElapsed); 190 191 var k1 = f.apply(x, u); 192 var k2 = f.apply(x.plus(k1.times(A[0][0]).times(h)), u); 193 var k3 = f.apply(x.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)), u); 194 var k4 = 195 f.apply( 196 x.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)), 197 u); 198 var k5 = 199 f.apply( 200 x.plus( 201 k1.times(A[3][0]) 202 .plus(k2.times(A[3][1])) 203 .plus(k3.times(A[3][2])) 204 .plus(k4.times(A[3][3])) 205 .times(h)), 206 u); 207 var k6 = 208 f.apply( 209 x.plus( 210 k1.times(A[4][0]) 211 .plus(k2.times(A[4][1])) 212 .plus(k3.times(A[4][2])) 213 .plus(k4.times(A[4][3])) 214 .plus(k5.times(A[4][4])) 215 .times(h)), 216 u); 217 218 // Since the final row of A and the array b1 have the same coefficients 219 // and k7 has no effect on newX, we can reuse the calculation. 220 newX = 221 x.plus( 222 k1.times(A[5][0]) 223 .plus(k2.times(A[5][1])) 224 .plus(k3.times(A[5][2])) 225 .plus(k4.times(A[5][3])) 226 .plus(k5.times(A[5][4])) 227 .plus(k6.times(A[5][5])) 228 .times(h)); 229 var k7 = f.apply(newX, u); 230 231 truncationError = 232 (k1.times(b1[0] - b2[0]) 233 .plus(k2.times(b1[1] - b2[1])) 234 .plus(k3.times(b1[2] - b2[2])) 235 .plus(k4.times(b1[3] - b2[3])) 236 .plus(k5.times(b1[4] - b2[4])) 237 .plus(k6.times(b1[5] - b2[5])) 238 .plus(k7.times(b1[6] - b2[6])) 239 .times(h)) 240 .normF(); 241 242 if (truncationError == 0.0) { 243 h = dtSeconds - dtElapsed; 244 } else { 245 h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0); 246 } 247 } while (truncationError > maxError); 248 249 dtElapsed += h; 250 x = newX; 251 } 252 253 return x; 254 } 255}