001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package org.wpilib.math.optimization; 006 007import static org.wpilib.math.optimization.Constraints.eq; 008import static org.wpilib.math.optimization.Constraints.ge; 009import static org.wpilib.math.optimization.Constraints.le; 010 011import java.util.function.BiConsumer; 012import java.util.function.BiFunction; 013import org.ejml.simple.SimpleMatrix; 014import org.wpilib.math.autodiff.Variable; 015import org.wpilib.math.autodiff.VariableBlock; 016import org.wpilib.math.autodiff.VariableMatrix; 017import org.wpilib.math.optimization.ocp.ConstraintEvaluationFunction; 018import org.wpilib.math.optimization.ocp.DynamicsFunction; 019import org.wpilib.math.optimization.ocp.DynamicsType; 020import org.wpilib.math.optimization.ocp.TimestepMethod; 021import org.wpilib.math.optimization.ocp.TranscriptionMethod; 022 023/** 024 * This class allows the user to pose and solve a constrained optimal control problem (OCP) in a 025 * variety of ways. 026 * 027 * <p>The system is transcripted by one of three methods (direct transcription, direct collocation, 028 * or single-shooting) and additional constraints can be added. 029 * 030 * <p>In direct transcription, each state is a decision variable constrained to the integrated 031 * dynamics of the previous state. In direct collocation, the trajectory is modeled as a series of 032 * cubic polynomials where the centerpoint slope is constrained. In single-shooting, states depend 033 * explicitly as a function of all previous states and all previous inputs. 034 * 035 * <p>Explicit ODEs are integrated using RK4. 036 * 037 * <p>For explicit ODEs, the function must be in the form dx/dt = f(t, x, u). For discrete state 038 * transition functions, the function must be in the form xₖ₊₁ = f(t, xₖ, uₖ). 039 * 040 * <p>Direct collocation requires an explicit ODE. Direct transcription and single-shooting can use 041 * either an ODE or state transition function. 042 * 043 * <p>https://underactuated.mit.edu/trajopt.html goes into more detail on each transcription method. 044 */ 045public class OCP extends Problem { 046 private int m_numSteps; 047 048 private DynamicsFunction m_dynamics; 049 private DynamicsType m_dynamicsType; 050 051 private VariableMatrix m_X; 052 private VariableMatrix m_U; 053 private VariableMatrix m_DT; 054 055 /** 056 * Builds an optimization problem using a system evolution function (explicit ODE or discrete 057 * state transition function). 058 * 059 * @param numStates The number of system states. 060 * @param numInputs The number of system inputs. 061 * @param dt The timestep for fixed-step integration. 062 * @param numSteps The number of control points. 063 * @param dynamics Function representing an explicit or implicit ODE, or a discrete state 064 * transition function. 065 * <ul> 066 * <li>Explicit: dx/dt = f(x, u, *) 067 * <li>Implicit: f([x dx/dt]', u, *) = 0 068 * <li>State transition: xₖ₊₁ = f(xₖ, uₖ) 069 * </ul> 070 * 071 * @param dynamicsType The type of system evolution function. 072 * @param timestepMethod The timestep method. 073 * @param transcriptionMethod The transcription method. 074 */ 075 public OCP( 076 int numStates, 077 int numInputs, 078 double dt, 079 int numSteps, 080 BiFunction<VariableMatrix, VariableMatrix, VariableMatrix> dynamics, 081 DynamicsType dynamicsType, 082 TimestepMethod timestepMethod, 083 TranscriptionMethod transcriptionMethod) { 084 this( 085 numStates, 086 numInputs, 087 dt, 088 numSteps, 089 (Variable t, VariableMatrix x, VariableMatrix u, Variable _dt) -> dynamics.apply(x, u), 090 dynamicsType, 091 timestepMethod, 092 transcriptionMethod); 093 } 094 095 /** 096 * Builds an optimization problem using a system evolution function (explicit ODE or discrete 097 * state transition function). 098 * 099 * @param numStates The number of system states. 100 * @param numInputs The number of system inputs. 101 * @param dt The timestep for fixed-step integration. 102 * @param numSteps The number of control points. 103 * @param dynamics Function representing an explicit or implicit ODE, or a discrete state 104 * transition function. 105 * <ul> 106 * <li>Explicit: dx/dt = f(t, x, u, *) 107 * <li>Implicit: f(t, [x dx/dt]', u, *) = 0 108 * <li>State transition: xₖ₊₁ = f(t, xₖ, uₖ, dt) 109 * </ul> 110 * 111 * @param dynamicsType The type of system evolution function. 112 * @param timestepMethod The timestep method. 113 * @param transcriptionMethod The transcription method. 114 */ 115 @SuppressWarnings("this-escape") 116 public OCP( 117 int numStates, 118 int numInputs, 119 double dt, 120 int numSteps, 121 DynamicsFunction dynamics, 122 DynamicsType dynamicsType, 123 TimestepMethod timestepMethod, 124 TranscriptionMethod transcriptionMethod) { 125 m_numSteps = numSteps; 126 m_dynamics = dynamics; 127 m_dynamicsType = dynamicsType; 128 129 // u is numSteps + 1 so that the final constraint function evaluation works 130 m_U = decisionVariable(numInputs, m_numSteps + 1); 131 132 if (timestepMethod == TimestepMethod.FIXED) { 133 m_DT = new VariableMatrix(1, m_numSteps + 1); 134 for (int i = 0; i < numSteps + 1; ++i) { 135 m_DT.set(0, i, dt); 136 } 137 } else if (timestepMethod == TimestepMethod.VARIABLE_SINGLE) { 138 Variable single_dt = decisionVariable(); 139 single_dt.setValue(dt); 140 141 // Set the member variable matrix to track the decision variable 142 m_DT = new VariableMatrix(1, m_numSteps + 1); 143 for (int i = 0; i < numSteps + 1; ++i) { 144 m_DT.set(0, i, single_dt); 145 } 146 } else if (timestepMethod == TimestepMethod.VARIABLE) { 147 m_DT = decisionVariable(1, m_numSteps + 1); 148 for (int i = 0; i < numSteps + 1; ++i) { 149 m_DT.get(0, i).setValue(dt); 150 } 151 } 152 153 if (transcriptionMethod == TranscriptionMethod.DIRECT_TRANSCRIPTION) { 154 m_X = decisionVariable(numStates, m_numSteps + 1); 155 constrainDirectTranscription(); 156 } else if (transcriptionMethod == TranscriptionMethod.DIRECT_COLLOCATION) { 157 m_X = decisionVariable(numStates, m_numSteps + 1); 158 constrainDirectCollocation(); 159 } else if (transcriptionMethod == TranscriptionMethod.SINGLE_SHOOTING) { 160 // In single-shooting the states aren't decision variables, but instead 161 // depend on the input and previous states 162 m_X = new VariableMatrix(numStates, m_numSteps + 1); 163 constrainSingleShooting(); 164 } 165 } 166 167 /** 168 * Constrains the initial state. 169 * 170 * @param initialState the initial state to constrain to. 171 */ 172 public void constrainInitialState(double initialState) { 173 subjectTo(eq(this.initialState(), initialState)); 174 } 175 176 /** 177 * Constrains the initial state. 178 * 179 * @param initialState the initial state to constrain to. 180 */ 181 public void constrainInitialState(Variable initialState) { 182 subjectTo(eq(this.initialState(), initialState)); 183 } 184 185 /** 186 * Constrains the initial state. 187 * 188 * @param initialState the initial state to constrain to. 189 */ 190 public void constrainInitialState(SimpleMatrix initialState) { 191 subjectTo(eq(this.initialState(), initialState)); 192 } 193 194 /** 195 * Constrains the initial state. 196 * 197 * @param initialState the initial state to constrain to. 198 */ 199 public void constrainInitialState(VariableMatrix initialState) { 200 subjectTo(eq(this.initialState(), initialState)); 201 } 202 203 /** 204 * Constrains the initial state. 205 * 206 * @param initialState the initial state to constrain to. 207 */ 208 public void constrainInitialState(VariableBlock initialState) { 209 subjectTo(eq(this.initialState(), initialState)); 210 } 211 212 /** 213 * Constrains the final state. 214 * 215 * @param finalState the final state to constrain to. 216 */ 217 public void constrainFinalState(double finalState) { 218 subjectTo(eq(this.finalState(), finalState)); 219 } 220 221 /** 222 * Constrains the final state. 223 * 224 * @param finalState the final state to constrain to. 225 */ 226 public void constrainFinalState(Variable finalState) { 227 subjectTo(eq(this.finalState(), finalState)); 228 } 229 230 /** 231 * Constrains the final state. 232 * 233 * @param finalState the final state to constrain to. 234 */ 235 public void constrainFinalState(SimpleMatrix finalState) { 236 subjectTo(eq(this.finalState(), finalState)); 237 } 238 239 /** 240 * Constrains the final state. 241 * 242 * @param finalState the final state to constrain to. 243 */ 244 public void constrainFinalState(VariableMatrix finalState) { 245 subjectTo(eq(this.finalState(), finalState)); 246 } 247 248 /** 249 * Constrains the final state. 250 * 251 * @param finalState the final state to constrain to. 252 */ 253 public void constrainFinalState(VariableBlock finalState) { 254 subjectTo(eq(this.finalState(), finalState)); 255 } 256 257 /** 258 * Sets the constraint evaluation function. This function is called `numSteps+1` times, with the 259 * corresponding state and input VariableMatrices. 260 * 261 * @param callback The callback f(x, u) where x is the state and u is the input vector. 262 */ 263 public void forEachStep(BiConsumer<VariableMatrix, VariableMatrix> callback) { 264 for (int i = 0; i < m_numSteps + 1; ++i) { 265 var x = X().col(i); 266 var u = U().col(i); 267 callback.accept(new VariableMatrix(x), new VariableMatrix(u)); 268 } 269 } 270 271 /** 272 * Sets the constraint evaluation function. This function is called `numSteps+1` times, with the 273 * corresponding state and input VariableMatrices. 274 * 275 * @param callback The callback f(t, x, u, dt) where t is time, x is the state vector, u is the 276 * input vector, and dt is the timestep duration. 277 */ 278 public void forEachStep(ConstraintEvaluationFunction callback) { 279 var time = new Variable(0.0); 280 281 for (int i = 0; i < m_numSteps + 1; ++i) { 282 var x = X().col(i); 283 var u = U().col(i); 284 var dt = this.dt().get(0, i); 285 callback.accept(time, new VariableMatrix(x), new VariableMatrix(u), dt); 286 287 time = time.plus(dt); 288 } 289 } 290 291 /** 292 * Sets a lower bound on the input. 293 * 294 * @param lowerBound The lower bound that inputs must always be above. Must be shaped 295 * (numInputs)x1. 296 */ 297 public void setLowerInputBound(double lowerBound) { 298 for (int i = 0; i < m_numSteps + 1; ++i) { 299 subjectTo(ge(U().col(i), lowerBound)); 300 } 301 } 302 303 /** 304 * Sets a lower bound on the input. 305 * 306 * @param lowerBound The lower bound that inputs must always be above. Must be shaped 307 * (numInputs)x1. 308 */ 309 public void setLowerInputBound(Variable lowerBound) { 310 for (int i = 0; i < m_numSteps + 1; ++i) { 311 subjectTo(ge(U().col(i), lowerBound)); 312 } 313 } 314 315 /** 316 * Sets a lower bound on the input. 317 * 318 * @param lowerBound The lower bound that inputs must always be above. Must be shaped 319 * (numInputs)x1. 320 */ 321 public void setLowerInputBound(SimpleMatrix lowerBound) { 322 for (int i = 0; i < m_numSteps + 1; ++i) { 323 subjectTo(ge(U().col(i), lowerBound)); 324 } 325 } 326 327 /** 328 * Sets a lower bound on the input. 329 * 330 * @param lowerBound The lower bound that inputs must always be above. Must be shaped 331 * (numInputs)x1. 332 */ 333 public void setLowerInputBound(VariableMatrix lowerBound) { 334 for (int i = 0; i < m_numSteps + 1; ++i) { 335 subjectTo(ge(U().col(i), lowerBound)); 336 } 337 } 338 339 /** 340 * Sets a lower bound on the input. 341 * 342 * @param lowerBound The lower bound that inputs must always be above. Must be shaped 343 * (numInputs)x1. 344 */ 345 public void setLowerInputBound(VariableBlock lowerBound) { 346 for (int i = 0; i < m_numSteps + 1; ++i) { 347 subjectTo(ge(U().col(i), lowerBound)); 348 } 349 } 350 351 /** 352 * Sets an upper bound on the input. 353 * 354 * @param upperBound The upper bound that inputs must always be below. Must be shaped 355 * (numInputs)x1. 356 */ 357 public void setUpperInputBound(double upperBound) { 358 for (int i = 0; i < m_numSteps + 1; ++i) { 359 subjectTo(le(U().col(i), upperBound)); 360 } 361 } 362 363 /** 364 * Sets an upper bound on the input. 365 * 366 * @param upperBound The upper bound that inputs must always be below. Must be shaped 367 * (numInputs)x1. 368 */ 369 public void setUpperInputBound(Variable upperBound) { 370 for (int i = 0; i < m_numSteps + 1; ++i) { 371 subjectTo(le(U().col(i), upperBound)); 372 } 373 } 374 375 /** 376 * Sets an upper bound on the input. 377 * 378 * @param upperBound The upper bound that inputs must always be below. Must be shaped 379 * (numInputs)x1. 380 */ 381 public void setUpperInputBound(SimpleMatrix upperBound) { 382 for (int i = 0; i < m_numSteps + 1; ++i) { 383 subjectTo(le(U().col(i), upperBound)); 384 } 385 } 386 387 /** 388 * Sets an upper bound on the input. 389 * 390 * @param upperBound The upper bound that inputs must always be below. Must be shaped 391 * (numInputs)x1. 392 */ 393 public void setUpperInputBound(VariableMatrix upperBound) { 394 for (int i = 0; i < m_numSteps + 1; ++i) { 395 subjectTo(le(U().col(i), upperBound)); 396 } 397 } 398 399 /** 400 * Sets an upper bound on the input. 401 * 402 * @param upperBound The upper bound that inputs must always be below. Must be shaped 403 * (numInputs)x1. 404 */ 405 public void setUpperInputBound(VariableBlock upperBound) { 406 for (int i = 0; i < m_numSteps + 1; ++i) { 407 subjectTo(le(U().col(i), upperBound)); 408 } 409 } 410 411 /** 412 * Sets a lower bound on the timestep. 413 * 414 * @param minTimestep The minimum timestep in seconds. 415 */ 416 public void setMinTimestep(double minTimestep) { 417 subjectTo(ge(dt(), minTimestep)); 418 } 419 420 /** 421 * Sets an upper bound on the timestep. 422 * 423 * @param maxTimestep The maximum timestep in seconds. 424 */ 425 public void setMaxTimestep(double maxTimestep) { 426 subjectTo(le(dt(), maxTimestep)); 427 } 428 429 /** 430 * Gets the state variables. After the problem is solved, this will contain the optimized 431 * trajectory. 432 * 433 * <p>Shaped (numStates)x(numSteps+1). 434 * 435 * @return The state variable matrix. 436 */ 437 public VariableMatrix X() { 438 return m_X; 439 } 440 441 /** 442 * Gets the input variables. After the problem is solved, this will contain the inputs 443 * corresponding to the optimized trajectory. 444 * 445 * <p>Shaped (numInputs)x(numSteps+1), although the last input step is unused in the trajectory. 446 * 447 * @return The input variable matrix. 448 */ 449 public VariableMatrix U() { 450 return m_U; 451 } 452 453 /** 454 * Gets the timestep variables. After the problem is solved, this will contain the timesteps 455 * corresponding to the optimized trajectory. 456 * 457 * <p>Shaped 1x(numSteps+1), although the last timestep is unused in the trajectory. 458 * 459 * @return The timestep variable matrix. 460 */ 461 public VariableMatrix dt() { 462 return m_DT; 463 } 464 465 /** 466 * Gets the initial state in the trajectory. 467 * 468 * @return The initial state of the trajectory. 469 */ 470 public VariableMatrix initialState() { 471 return new VariableMatrix(m_X.col(0)); 472 } 473 474 /** 475 * Gets the final state in the trajectory. 476 * 477 * @return The final state of the trajectory. 478 */ 479 public VariableMatrix finalState() { 480 return new VariableMatrix(m_X.col(m_numSteps)); 481 } 482 483 /** 484 * Performs 4th order Runge-Kutta integration of dx/dt = f(t, x, u) for dt. 485 * 486 * @param f The function to integrate. It must take two arguments x and u. 487 * @param x The initial value of x. 488 * @param u The value u held constant over the integration period. 489 * @param t0 The initial time. 490 * @param dt The time over which to integrate. 491 */ 492 private static VariableMatrix rk4( 493 DynamicsFunction f, VariableMatrix x, VariableMatrix u, Variable t0, Variable dt) { 494 var halfdt = dt.times(0.5); 495 VariableMatrix k1 = f.apply(t0, x, u, dt); 496 VariableMatrix k2 = f.apply(t0.plus(halfdt), x.plus(k1.times(halfdt)), u, dt); 497 VariableMatrix k3 = f.apply(t0.plus(halfdt), x.plus(k2.times(halfdt)), u, dt); 498 VariableMatrix k4 = f.apply(t0.plus(dt), x.plus(k3.times(dt)), u, dt); 499 500 return x.plus(k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4).times(dt.div(6.0))); 501 } 502 503 /** Applies direct collocation dynamics constraints. */ 504 private void constrainDirectCollocation() { 505 assert m_dynamicsType == DynamicsType.EXPLICIT_ODE; 506 507 var time = new Variable(0.0); 508 509 // Derivation at https://mec560sbu.github.io/2016/09/30/direct_collocation/ 510 for (int i = 0; i < m_numSteps; ++i) { 511 Variable h = dt().get(0, i); 512 513 var f = m_dynamics; 514 515 var t_begin = time; 516 var t_end = t_begin.plus(h); 517 518 var x_begin = X().col(i); 519 var x_end = X().col(i + 1); 520 521 var u_begin = U().col(i); 522 var u_end = U().col(i + 1); 523 524 var xdot_begin = 525 f.apply(t_begin, new VariableMatrix(x_begin), new VariableMatrix(u_begin), h); 526 var xdot_end = f.apply(t_end, new VariableMatrix(x_end), new VariableMatrix(u_end), h); 527 var xdot_c = 528 x_begin 529 .minus(x_end) 530 .times(new Variable(-3).div(h.times(2))) 531 .minus(xdot_begin.plus(xdot_end).times(0.25)); 532 533 var t_c = t_begin.plus(h.times(0.5)); 534 var x_c = x_begin.plus(x_end).times(0.5).plus(xdot_begin.minus(xdot_end).times(h.div(8))); 535 var u_c = u_begin.plus(u_end).times(0.5); 536 537 subjectTo(eq(xdot_c, f.apply(t_c, x_c, u_c, h))); 538 539 time = time.plus(h); 540 } 541 } 542 543 /** Applies direct transcription dynamics constraints. */ 544 private void constrainDirectTranscription() { 545 var time = new Variable(0.0); 546 547 for (int i = 0; i < m_numSteps; ++i) { 548 var x_begin = X().col(i); 549 var x_end = X().col(i + 1); 550 var u = U().col(i); 551 Variable dt = this.dt().get(0, i); 552 553 if (m_dynamicsType == DynamicsType.EXPLICIT_ODE) { 554 subjectTo( 555 eq( 556 x_end, 557 rk4(m_dynamics, new VariableMatrix(x_begin), new VariableMatrix(u), time, dt))); 558 } else if (m_dynamicsType == DynamicsType.DISCRETE) { 559 subjectTo( 560 eq( 561 x_end, 562 m_dynamics.apply(time, new VariableMatrix(x_begin), new VariableMatrix(u), dt))); 563 } 564 565 time = time.plus(dt); 566 } 567 } 568 569 /** Applies single shooting dynamics constraints. */ 570 private void constrainSingleShooting() { 571 var time = new Variable(0.0); 572 573 for (int i = 0; i < m_numSteps; ++i) { 574 var x_begin = X().col(i); 575 var x_end = X().col(i + 1); 576 var u = U().col(i); 577 Variable dt = this.dt().get(0, i); 578 579 if (m_dynamicsType == DynamicsType.EXPLICIT_ODE) { 580 x_end.set(rk4(m_dynamics, new VariableMatrix(x_begin), new VariableMatrix(u), time, dt)); 581 } else if (m_dynamicsType == DynamicsType.DISCRETE) { 582 x_end.set(m_dynamics.apply(time, new VariableMatrix(x_begin), new VariableMatrix(u), dt)); 583 } 584 585 time = time.plus(dt); 586 } 587 } 588}