001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package org.wpilib.math.system; 006 007import java.util.function.BiFunction; 008import java.util.function.DoubleBinaryOperator; 009import java.util.function.DoubleUnaryOperator; 010import java.util.function.UnaryOperator; 011import org.ejml.simple.SimpleMatrix; 012import org.wpilib.math.linalg.Matrix; 013import org.wpilib.math.numbers.N1; 014import org.wpilib.math.util.Num; 015 016/** Numerical integration utilities. */ 017public final class NumericalIntegration { 018 private NumericalIntegration() { 019 // utility Class 020 } 021 022 /** 023 * Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt. 024 * 025 * @param f The function to integrate. It must take one argument x. 026 * @param x The initial value of x. 027 * @param dt The time over which to integrate in seconds. 028 * @return the integration of dx/dt = f(x) for dt. 029 */ 030 public static double rk4(DoubleUnaryOperator f, double x, double dt) { 031 final var h = dt; 032 final var k1 = f.applyAsDouble(x); 033 final var k2 = f.applyAsDouble(x + h * k1 * 0.5); 034 final var k3 = f.applyAsDouble(x + h * k2 * 0.5); 035 final var k4 = f.applyAsDouble(x + h * k3); 036 037 return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4); 038 } 039 040 /** 041 * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt. 042 * 043 * @param f The function to integrate. It must take two arguments x and u. 044 * @param x The initial value of x. 045 * @param u The value u held constant over the integration period. 046 * @param dt The time over which to integrate in seconds. 047 * @return the integration of dx/dt = f(x, u) for dt. 048 */ 049 public static double rk4(DoubleBinaryOperator f, double x, double u, double dt) { 050 final var h = dt; 051 052 final var k1 = f.applyAsDouble(x, u); 053 final var k2 = f.applyAsDouble(x + h * k1 * 0.5, u); 054 final var k3 = f.applyAsDouble(x + h * k2 * 0.5, u); 055 final var k4 = f.applyAsDouble(x + h * k3, u); 056 057 return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4); 058 } 059 060 /** 061 * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt. 062 * 063 * @param f The function to integrate. It must take two arguments x and u. 064 * @param x The initial value of x. 065 * @param u The value u held constant over the integration period. 066 * @param dt The time over which to integrate. 067 * @return the integration of dx/dt = f(x, u) for dt. 068 */ 069 public static SimpleMatrix rk4( 070 BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> f, 071 SimpleMatrix x, 072 SimpleMatrix u, 073 double dt) { 074 var h = dt; 075 076 var k1 = f.apply(x, u); 077 var k2 = f.apply(x.plus(k1.scale(h * 0.5)), u); 078 var k3 = f.apply(x.plus(k2.scale(h * 0.5)), u); 079 var k4 = f.apply(x.plus(k3.scale(h)), u); 080 081 return x.plus(k1.plus(k2.scale(2.0)).plus(k3.scale(2.0)).plus(k4).scale(h / 6.0)); 082 } 083 084 /** 085 * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt. 086 * 087 * @param <States> A Num representing the states of the system to integrate. 088 * @param <Inputs> A Num representing the inputs of the system to integrate. 089 * @param f The function to integrate. It must take two arguments x and u. 090 * @param x The initial value of x. 091 * @param u The value u held constant over the integration period. 092 * @param dt The time over which to integrate in seconds. 093 * @return the integration of dx/dt = f(x, u) for dt. 094 */ 095 public static <States extends Num, Inputs extends Num> Matrix<States, N1> rk4( 096 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 097 Matrix<States, N1> x, 098 Matrix<Inputs, N1> u, 099 double dt) { 100 final var h = dt; 101 102 Matrix<States, N1> k1 = f.apply(x, u); 103 Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)), u); 104 Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)), u); 105 Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)), u); 106 107 return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0)); 108 } 109 110 /** 111 * Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt. 112 * 113 * @param <States> A Num representing the states of the system. 114 * @param f The function to integrate. It must take one argument x. 115 * @param x The initial value of x. 116 * @param dt The time over which to integrate in seconds. 117 * @return the integration of dx/dt = f(x) for dt. 118 */ 119 public static <States extends Num> Matrix<States, N1> rk4( 120 UnaryOperator<Matrix<States, N1>> f, Matrix<States, N1> x, double dt) { 121 final var h = dt; 122 123 Matrix<States, N1> k1 = f.apply(x); 124 Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5))); 125 Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5))); 126 Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h))); 127 128 return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0)); 129 } 130 131 /** 132 * Performs 4th order Runge-Kutta integration of dx/dt = f(t, y) for dt. 133 * 134 * @param <Rows> Rows in y. 135 * @param <Cols> Columns in y. 136 * @param f The function to integrate. It must take two arguments t and y. 137 * @param t The initial value of t. 138 * @param y The initial value of y. 139 * @param dt The time over which to integrate in seconds. 140 * @return the integration of dx/dt = f(x) for dt. 141 */ 142 public static <Rows extends Num, Cols extends Num> Matrix<Rows, Cols> rk4( 143 BiFunction<Double, Matrix<Rows, Cols>, Matrix<Rows, Cols>> f, 144 double t, 145 Matrix<Rows, Cols> y, 146 double dt) { 147 final var h = dt; 148 149 Matrix<Rows, Cols> k1 = f.apply(t, y); 150 Matrix<Rows, Cols> k2 = f.apply(t + dt * 0.5, y.plus(k1.times(h * 0.5))); 151 Matrix<Rows, Cols> k3 = f.apply(t + dt * 0.5, y.plus(k2.times(h * 0.5))); 152 Matrix<Rows, Cols> k4 = f.apply(t + dt, y.plus(k3.times(h))); 153 154 return y.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0)); 155 } 156 157 /** 158 * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt. By default, the max 159 * error is 1e-6. 160 * 161 * @param <States> A Num representing the states of the system to integrate. 162 * @param <Inputs> A Num representing the inputs of the system to integrate. 163 * @param f The function to integrate. It must take two arguments x and u. 164 * @param x The initial value of x. 165 * @param u The value u held constant over the integration period. 166 * @param dt The time over which to integrate in seconds. 167 * @return the integration of dx/dt = f(x, u) for dt. 168 */ 169 public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp( 170 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 171 Matrix<States, N1> x, 172 Matrix<Inputs, N1> u, 173 double dt) { 174 return rkdp(f, x, u, dt, 1e-6); 175 } 176 177 /** 178 * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt. 179 * 180 * @param <States> A Num representing the states of the system to integrate. 181 * @param <Inputs> A Num representing the inputs of the system to integrate. 182 * @param f The function to integrate. It must take two arguments x and u. 183 * @param x The initial value of x. 184 * @param u The value u held constant over the integration period. 185 * @param dt The time over which to integrate in seconds. 186 * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6. 187 * @return the integration of dx/dt = f(x, u) for dt. 188 */ 189 public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp( 190 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 191 Matrix<States, N1> x, 192 Matrix<Inputs, N1> u, 193 double dt, 194 double maxError) { 195 // See https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method for the 196 // Butcher tableau the following arrays came from. 197 198 // final double[6][6] 199 final double[][] A = { 200 {1.0 / 5.0}, 201 {3.0 / 40.0, 9.0 / 40.0}, 202 {44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0}, 203 {19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0}, 204 {9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0}, 205 {35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0} 206 }; 207 208 // final double[7] 209 final double[] b1 = { 210 35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0 211 }; 212 213 // final double[7] 214 final double[] b2 = { 215 5179.0 / 57600.0, 216 0.0, 217 7571.0 / 16695.0, 218 393.0 / 640.0, 219 -92097.0 / 339200.0, 220 187.0 / 2100.0, 221 1.0 / 40.0 222 }; 223 224 Matrix<States, N1> newX; 225 double truncationError; 226 227 double dtElapsed = 0.0; 228 double h = dt; 229 230 // Loop until we've gotten to our desired dt 231 while (dtElapsed < dt) { 232 do { 233 // Only allow us to advance up to the dt remaining 234 h = Math.min(h, dt - dtElapsed); 235 236 var k1 = f.apply(x, u); 237 var k2 = f.apply(x.plus(k1.times(A[0][0]).times(h)), u); 238 var k3 = f.apply(x.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)), u); 239 var k4 = 240 f.apply( 241 x.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)), 242 u); 243 var k5 = 244 f.apply( 245 x.plus( 246 k1.times(A[3][0]) 247 .plus(k2.times(A[3][1])) 248 .plus(k3.times(A[3][2])) 249 .plus(k4.times(A[3][3])) 250 .times(h)), 251 u); 252 var k6 = 253 f.apply( 254 x.plus( 255 k1.times(A[4][0]) 256 .plus(k2.times(A[4][1])) 257 .plus(k3.times(A[4][2])) 258 .plus(k4.times(A[4][3])) 259 .plus(k5.times(A[4][4])) 260 .times(h)), 261 u); 262 263 // Since the final row of A and the array b1 have the same coefficients 264 // and k7 has no effect on newX, we can reuse the calculation. 265 newX = 266 x.plus( 267 k1.times(A[5][0]) 268 .plus(k2.times(A[5][1])) 269 .plus(k3.times(A[5][2])) 270 .plus(k4.times(A[5][3])) 271 .plus(k5.times(A[5][4])) 272 .plus(k6.times(A[5][5])) 273 .times(h)); 274 var k7 = f.apply(newX, u); 275 276 truncationError = 277 (k1.times(b1[0] - b2[0]) 278 .plus(k2.times(b1[1] - b2[1])) 279 .plus(k3.times(b1[2] - b2[2])) 280 .plus(k4.times(b1[3] - b2[3])) 281 .plus(k5.times(b1[4] - b2[4])) 282 .plus(k6.times(b1[5] - b2[5])) 283 .plus(k7.times(b1[6] - b2[6])) 284 .times(h)) 285 .normF(); 286 287 if (truncationError == 0.0) { 288 h = dt - dtElapsed; 289 } else { 290 h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0); 291 } 292 } while (truncationError > maxError); 293 294 dtElapsed += h; 295 x = newX; 296 } 297 298 return x; 299 } 300 301 /** 302 * Performs adaptive Dormand-Prince integration of dx/dt = f(t, y) for dt. 303 * 304 * @param <Rows> Rows in y. 305 * @param <Cols> Columns in y. 306 * @param f The function to integrate. It must take two arguments t and y. 307 * @param t The initial value of t. 308 * @param y The initial value of y. 309 * @param dt The time over which to integrate in seconds. 310 * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6. 311 * @return the integration of dx/dt = f(x, u) for dt. 312 */ 313 public static <Rows extends Num, Cols extends Num> Matrix<Rows, Cols> rkdp( 314 BiFunction<Double, Matrix<Rows, Cols>, Matrix<Rows, Cols>> f, 315 double t, 316 Matrix<Rows, Cols> y, 317 double dt, 318 double maxError) { 319 // See https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method for the 320 // Butcher tableau the following arrays came from. 321 322 // final double[6][6] 323 final double[][] A = { 324 {1.0 / 5.0}, 325 {3.0 / 40.0, 9.0 / 40.0}, 326 {44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0}, 327 {19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0}, 328 {9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0}, 329 {35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0} 330 }; 331 332 // final double[7] 333 final double[] b1 = { 334 35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0 335 }; 336 337 // final double[7] 338 final double[] b2 = { 339 5179.0 / 57600.0, 340 0.0, 341 7571.0 / 16695.0, 342 393.0 / 640.0, 343 -92097.0 / 339200.0, 344 187.0 / 2100.0, 345 1.0 / 40.0 346 }; 347 348 // final double[6] 349 final double[] c = {1.0 / 5.0, 3.0 / 10.0, 4.0 / 5.0, 8.0 / 9.0, 1.0, 1.0}; 350 351 Matrix<Rows, Cols> newY; 352 double truncationError; 353 354 double dtElapsed = 0.0; 355 double h = dt; 356 357 // Loop until we've gotten to our desired dt 358 while (dtElapsed < dt) { 359 do { 360 // Only allow us to advance up to the dt remaining 361 h = Math.min(h, dt - dtElapsed); 362 363 var k1 = f.apply(t, y); 364 var k2 = f.apply(t + h * c[0], y.plus(k1.times(A[0][0]).times(h))); 365 var k3 = f.apply(t + h * c[1], y.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h))); 366 var k4 = 367 f.apply( 368 t + h * c[2], 369 y.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h))); 370 var k5 = 371 f.apply( 372 t + h * c[3], 373 y.plus( 374 k1.times(A[3][0]) 375 .plus(k2.times(A[3][1])) 376 .plus(k3.times(A[3][2])) 377 .plus(k4.times(A[3][3])) 378 .times(h))); 379 var k6 = 380 f.apply( 381 t + h * c[4], 382 y.plus( 383 k1.times(A[4][0]) 384 .plus(k2.times(A[4][1])) 385 .plus(k3.times(A[4][2])) 386 .plus(k4.times(A[4][3])) 387 .plus(k5.times(A[4][4])) 388 .times(h))); 389 390 // Since the final row of A and the array b1 have the same coefficients 391 // and k7 has no effect on newY, we can reuse the calculation. 392 newY = 393 y.plus( 394 k1.times(A[5][0]) 395 .plus(k2.times(A[5][1])) 396 .plus(k3.times(A[5][2])) 397 .plus(k4.times(A[5][3])) 398 .plus(k5.times(A[5][4])) 399 .plus(k6.times(A[5][5])) 400 .times(h)); 401 var k7 = f.apply(t + h * c[5], newY); 402 403 truncationError = 404 (k1.times(b1[0] - b2[0]) 405 .plus(k2.times(b1[1] - b2[1])) 406 .plus(k3.times(b1[2] - b2[2])) 407 .plus(k4.times(b1[3] - b2[3])) 408 .plus(k5.times(b1[4] - b2[4])) 409 .plus(k6.times(b1[5] - b2[5])) 410 .plus(k7.times(b1[6] - b2[6])) 411 .times(h)) 412 .normF(); 413 414 if (truncationError == 0.0) { 415 h = dt - dtElapsed; 416 } else { 417 h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0); 418 } 419 } while (truncationError > maxError); 420 421 dtElapsed += h; 422 y = newY; 423 } 424 425 return y; 426 } 427}