001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package org.wpilib.math.optimization;
006
007import static org.wpilib.math.optimization.Constraints.eq;
008import static org.wpilib.math.optimization.Constraints.ge;
009import static org.wpilib.math.optimization.Constraints.le;
010
011import java.util.function.BiConsumer;
012import java.util.function.BiFunction;
013import org.ejml.simple.SimpleMatrix;
014import org.wpilib.math.autodiff.Variable;
015import org.wpilib.math.autodiff.VariableBlock;
016import org.wpilib.math.autodiff.VariableMatrix;
017import org.wpilib.math.optimization.ocp.ConstraintEvaluationFunction;
018import org.wpilib.math.optimization.ocp.DynamicsFunction;
019import org.wpilib.math.optimization.ocp.DynamicsType;
020import org.wpilib.math.optimization.ocp.TimestepMethod;
021import org.wpilib.math.optimization.ocp.TranscriptionMethod;
022
023/**
024 * This class allows the user to pose and solve a constrained optimal control problem (OCP) in a
025 * variety of ways.
026 *
 * <p>The system is transcribed by one of three methods (direct transcription, direct collocation,
 * or single-shooting) and additional constraints can be added.
029 *
 * <p>In direct transcription, each state is a decision variable constrained to the integrated
 * dynamics of the previous state. In direct collocation, the trajectory is modeled as a series of
 * cubic polynomials where the centerpoint slope is constrained. In single-shooting, each state is
 * expressed explicitly as a function of all previous states and all previous inputs.
034 *
035 * <p>Explicit ODEs are integrated using RK4.
036 *
037 * <p>For explicit ODEs, the function must be in the form dx/dt = f(t, x, u). For discrete state
038 * transition functions, the function must be in the form xₖ₊₁ = f(t, xₖ, uₖ).
039 *
040 * <p>Direct collocation requires an explicit ODE. Direct transcription and single-shooting can use
041 * either an ODE or state transition function.
042 *
043 * <p>https://underactuated.mit.edu/trajopt.html goes into more detail on each transcription method.
044 */
045public class OCP extends Problem {
046  private int m_numSteps;
047
048  private DynamicsFunction m_dynamics;
049  private DynamicsType m_dynamicsType;
050
051  private VariableMatrix m_X;
052  private VariableMatrix m_U;
053  private VariableMatrix m_DT;
054
055  /**
056   * Builds an optimization problem using a system evolution function (explicit ODE or discrete
057   * state transition function).
058   *
059   * @param numStates The number of system states.
060   * @param numInputs The number of system inputs.
061   * @param dt The timestep for fixed-step integration.
062   * @param numSteps The number of control points.
063   * @param dynamics Function representing an explicit or implicit ODE, or a discrete state
064   *     transition function.
065   *     <ul>
066   *       <li>Explicit: dx/dt = f(x, u, *)
067   *       <li>Implicit: f([x dx/dt]', u, *) = 0
068   *       <li>State transition: xₖ₊₁ = f(xₖ, uₖ)
069   *     </ul>
070   *
071   * @param dynamicsType The type of system evolution function.
072   * @param timestepMethod The timestep method.
073   * @param transcriptionMethod The transcription method.
074   */
075  public OCP(
076      int numStates,
077      int numInputs,
078      double dt,
079      int numSteps,
080      BiFunction<VariableMatrix, VariableMatrix, VariableMatrix> dynamics,
081      DynamicsType dynamicsType,
082      TimestepMethod timestepMethod,
083      TranscriptionMethod transcriptionMethod) {
084    this(
085        numStates,
086        numInputs,
087        dt,
088        numSteps,
089        (Variable t, VariableMatrix x, VariableMatrix u, Variable _dt) -> dynamics.apply(x, u),
090        dynamicsType,
091        timestepMethod,
092        transcriptionMethod);
093  }
094
095  /**
096   * Builds an optimization problem using a system evolution function (explicit ODE or discrete
097   * state transition function).
098   *
099   * @param numStates The number of system states.
100   * @param numInputs The number of system inputs.
101   * @param dt The timestep for fixed-step integration.
102   * @param numSteps The number of control points.
103   * @param dynamics Function representing an explicit or implicit ODE, or a discrete state
104   *     transition function.
105   *     <ul>
106   *       <li>Explicit: dx/dt = f(t, x, u, *)
107   *       <li>Implicit: f(t, [x dx/dt]', u, *) = 0
108   *       <li>State transition: xₖ₊₁ = f(t, xₖ, uₖ, dt)
109   *     </ul>
110   *
111   * @param dynamicsType The type of system evolution function.
112   * @param timestepMethod The timestep method.
113   * @param transcriptionMethod The transcription method.
114   */
115  @SuppressWarnings("this-escape")
116  public OCP(
117      int numStates,
118      int numInputs,
119      double dt,
120      int numSteps,
121      DynamicsFunction dynamics,
122      DynamicsType dynamicsType,
123      TimestepMethod timestepMethod,
124      TranscriptionMethod transcriptionMethod) {
125    m_numSteps = numSteps;
126    m_dynamics = dynamics;
127    m_dynamicsType = dynamicsType;
128
129    // u is numSteps + 1 so that the final constraint function evaluation works
130    m_U = decisionVariable(numInputs, m_numSteps + 1);
131
132    if (timestepMethod == TimestepMethod.FIXED) {
133      m_DT = new VariableMatrix(1, m_numSteps + 1);
134      for (int i = 0; i < numSteps + 1; ++i) {
135        m_DT.set(0, i, dt);
136      }
137    } else if (timestepMethod == TimestepMethod.VARIABLE_SINGLE) {
138      Variable single_dt = decisionVariable();
139      single_dt.setValue(dt);
140
141      // Set the member variable matrix to track the decision variable
142      m_DT = new VariableMatrix(1, m_numSteps + 1);
143      for (int i = 0; i < numSteps + 1; ++i) {
144        m_DT.set(0, i, single_dt);
145      }
146    } else if (timestepMethod == TimestepMethod.VARIABLE) {
147      m_DT = decisionVariable(1, m_numSteps + 1);
148      for (int i = 0; i < numSteps + 1; ++i) {
149        m_DT.get(0, i).setValue(dt);
150      }
151    }
152
153    if (transcriptionMethod == TranscriptionMethod.DIRECT_TRANSCRIPTION) {
154      m_X = decisionVariable(numStates, m_numSteps + 1);
155      constrainDirectTranscription();
156    } else if (transcriptionMethod == TranscriptionMethod.DIRECT_COLLOCATION) {
157      m_X = decisionVariable(numStates, m_numSteps + 1);
158      constrainDirectCollocation();
159    } else if (transcriptionMethod == TranscriptionMethod.SINGLE_SHOOTING) {
160      // In single-shooting the states aren't decision variables, but instead
161      // depend on the input and previous states
162      m_X = new VariableMatrix(numStates, m_numSteps + 1);
163      constrainSingleShooting();
164    }
165  }
166
167  /**
168   * Constrains the initial state.
169   *
170   * @param initialState the initial state to constrain to.
171   */
172  public void constrainInitialState(double initialState) {
173    subjectTo(eq(this.initialState(), initialState));
174  }
175
176  /**
177   * Constrains the initial state.
178   *
179   * @param initialState the initial state to constrain to.
180   */
181  public void constrainInitialState(Variable initialState) {
182    subjectTo(eq(this.initialState(), initialState));
183  }
184
185  /**
186   * Constrains the initial state.
187   *
188   * @param initialState the initial state to constrain to.
189   */
190  public void constrainInitialState(SimpleMatrix initialState) {
191    subjectTo(eq(this.initialState(), initialState));
192  }
193
194  /**
195   * Constrains the initial state.
196   *
197   * @param initialState the initial state to constrain to.
198   */
199  public void constrainInitialState(VariableMatrix initialState) {
200    subjectTo(eq(this.initialState(), initialState));
201  }
202
203  /**
204   * Constrains the initial state.
205   *
206   * @param initialState the initial state to constrain to.
207   */
208  public void constrainInitialState(VariableBlock initialState) {
209    subjectTo(eq(this.initialState(), initialState));
210  }
211
212  /**
213   * Constrains the final state.
214   *
215   * @param finalState the final state to constrain to.
216   */
217  public void constrainFinalState(double finalState) {
218    subjectTo(eq(this.finalState(), finalState));
219  }
220
221  /**
222   * Constrains the final state.
223   *
224   * @param finalState the final state to constrain to.
225   */
226  public void constrainFinalState(Variable finalState) {
227    subjectTo(eq(this.finalState(), finalState));
228  }
229
230  /**
231   * Constrains the final state.
232   *
233   * @param finalState the final state to constrain to.
234   */
235  public void constrainFinalState(SimpleMatrix finalState) {
236    subjectTo(eq(this.finalState(), finalState));
237  }
238
239  /**
240   * Constrains the final state.
241   *
242   * @param finalState the final state to constrain to.
243   */
244  public void constrainFinalState(VariableMatrix finalState) {
245    subjectTo(eq(this.finalState(), finalState));
246  }
247
248  /**
249   * Constrains the final state.
250   *
251   * @param finalState the final state to constrain to.
252   */
253  public void constrainFinalState(VariableBlock finalState) {
254    subjectTo(eq(this.finalState(), finalState));
255  }
256
257  /**
258   * Sets the constraint evaluation function. This function is called `numSteps+1` times, with the
259   * corresponding state and input VariableMatrices.
260   *
261   * @param callback The callback f(x, u) where x is the state and u is the input vector.
262   */
263  public void forEachStep(BiConsumer<VariableMatrix, VariableMatrix> callback) {
264    for (int i = 0; i < m_numSteps + 1; ++i) {
265      var x = X().col(i);
266      var u = U().col(i);
267      callback.accept(new VariableMatrix(x), new VariableMatrix(u));
268    }
269  }
270
271  /**
272   * Sets the constraint evaluation function. This function is called `numSteps+1` times, with the
273   * corresponding state and input VariableMatrices.
274   *
275   * @param callback The callback f(t, x, u, dt) where t is time, x is the state vector, u is the
276   *     input vector, and dt is the timestep duration.
277   */
278  public void forEachStep(ConstraintEvaluationFunction callback) {
279    var time = new Variable(0.0);
280
281    for (int i = 0; i < m_numSteps + 1; ++i) {
282      var x = X().col(i);
283      var u = U().col(i);
284      var dt = this.dt().get(0, i);
285      callback.accept(time, new VariableMatrix(x), new VariableMatrix(u), dt);
286
287      time = time.plus(dt);
288    }
289  }
290
291  /**
292   * Sets a lower bound on the input.
293   *
294   * @param lowerBound The lower bound that inputs must always be above. Must be shaped
295   *     (numInputs)x1.
296   */
297  public void setLowerInputBound(double lowerBound) {
298    for (int i = 0; i < m_numSteps + 1; ++i) {
299      subjectTo(ge(U().col(i), lowerBound));
300    }
301  }
302
303  /**
304   * Sets a lower bound on the input.
305   *
306   * @param lowerBound The lower bound that inputs must always be above. Must be shaped
307   *     (numInputs)x1.
308   */
309  public void setLowerInputBound(Variable lowerBound) {
310    for (int i = 0; i < m_numSteps + 1; ++i) {
311      subjectTo(ge(U().col(i), lowerBound));
312    }
313  }
314
315  /**
316   * Sets a lower bound on the input.
317   *
318   * @param lowerBound The lower bound that inputs must always be above. Must be shaped
319   *     (numInputs)x1.
320   */
321  public void setLowerInputBound(SimpleMatrix lowerBound) {
322    for (int i = 0; i < m_numSteps + 1; ++i) {
323      subjectTo(ge(U().col(i), lowerBound));
324    }
325  }
326
327  /**
328   * Sets a lower bound on the input.
329   *
330   * @param lowerBound The lower bound that inputs must always be above. Must be shaped
331   *     (numInputs)x1.
332   */
333  public void setLowerInputBound(VariableMatrix lowerBound) {
334    for (int i = 0; i < m_numSteps + 1; ++i) {
335      subjectTo(ge(U().col(i), lowerBound));
336    }
337  }
338
339  /**
340   * Sets a lower bound on the input.
341   *
342   * @param lowerBound The lower bound that inputs must always be above. Must be shaped
343   *     (numInputs)x1.
344   */
345  public void setLowerInputBound(VariableBlock lowerBound) {
346    for (int i = 0; i < m_numSteps + 1; ++i) {
347      subjectTo(ge(U().col(i), lowerBound));
348    }
349  }
350
351  /**
352   * Sets an upper bound on the input.
353   *
354   * @param upperBound The upper bound that inputs must always be below. Must be shaped
355   *     (numInputs)x1.
356   */
357  public void setUpperInputBound(double upperBound) {
358    for (int i = 0; i < m_numSteps + 1; ++i) {
359      subjectTo(le(U().col(i), upperBound));
360    }
361  }
362
363  /**
364   * Sets an upper bound on the input.
365   *
366   * @param upperBound The upper bound that inputs must always be below. Must be shaped
367   *     (numInputs)x1.
368   */
369  public void setUpperInputBound(Variable upperBound) {
370    for (int i = 0; i < m_numSteps + 1; ++i) {
371      subjectTo(le(U().col(i), upperBound));
372    }
373  }
374
375  /**
376   * Sets an upper bound on the input.
377   *
378   * @param upperBound The upper bound that inputs must always be below. Must be shaped
379   *     (numInputs)x1.
380   */
381  public void setUpperInputBound(SimpleMatrix upperBound) {
382    for (int i = 0; i < m_numSteps + 1; ++i) {
383      subjectTo(le(U().col(i), upperBound));
384    }
385  }
386
387  /**
388   * Sets an upper bound on the input.
389   *
390   * @param upperBound The upper bound that inputs must always be below. Must be shaped
391   *     (numInputs)x1.
392   */
393  public void setUpperInputBound(VariableMatrix upperBound) {
394    for (int i = 0; i < m_numSteps + 1; ++i) {
395      subjectTo(le(U().col(i), upperBound));
396    }
397  }
398
399  /**
400   * Sets an upper bound on the input.
401   *
402   * @param upperBound The upper bound that inputs must always be below. Must be shaped
403   *     (numInputs)x1.
404   */
405  public void setUpperInputBound(VariableBlock upperBound) {
406    for (int i = 0; i < m_numSteps + 1; ++i) {
407      subjectTo(le(U().col(i), upperBound));
408    }
409  }
410
411  /**
412   * Sets a lower bound on the timestep.
413   *
414   * @param minTimestep The minimum timestep in seconds.
415   */
416  public void setMinTimestep(double minTimestep) {
417    subjectTo(ge(dt(), minTimestep));
418  }
419
420  /**
421   * Sets an upper bound on the timestep.
422   *
423   * @param maxTimestep The maximum timestep in seconds.
424   */
425  public void setMaxTimestep(double maxTimestep) {
426    subjectTo(le(dt(), maxTimestep));
427  }
428
429  /**
430   * Gets the state variables. After the problem is solved, this will contain the optimized
431   * trajectory.
432   *
433   * <p>Shaped (numStates)x(numSteps+1).
434   *
435   * @return The state variable matrix.
436   */
437  public VariableMatrix X() {
438    return m_X;
439  }
440
441  /**
442   * Gets the input variables. After the problem is solved, this will contain the inputs
443   * corresponding to the optimized trajectory.
444   *
445   * <p>Shaped (numInputs)x(numSteps+1), although the last input step is unused in the trajectory.
446   *
447   * @return The input variable matrix.
448   */
449  public VariableMatrix U() {
450    return m_U;
451  }
452
453  /**
454   * Gets the timestep variables. After the problem is solved, this will contain the timesteps
455   * corresponding to the optimized trajectory.
456   *
457   * <p>Shaped 1x(numSteps+1), although the last timestep is unused in the trajectory.
458   *
459   * @return The timestep variable matrix.
460   */
461  public VariableMatrix dt() {
462    return m_DT;
463  }
464
465  /**
466   * Gets the initial state in the trajectory.
467   *
468   * @return The initial state of the trajectory.
469   */
470  public VariableMatrix initialState() {
471    return new VariableMatrix(m_X.col(0));
472  }
473
474  /**
475   * Gets the final state in the trajectory.
476   *
477   * @return The final state of the trajectory.
478   */
479  public VariableMatrix finalState() {
480    return new VariableMatrix(m_X.col(m_numSteps));
481  }
482
483  /**
484   * Performs 4th order Runge-Kutta integration of dx/dt = f(t, x, u) for dt.
485   *
486   * @param f The function to integrate. It must take two arguments x and u.
487   * @param x The initial value of x.
488   * @param u The value u held constant over the integration period.
489   * @param t0 The initial time.
490   * @param dt The time over which to integrate.
491   */
492  private static VariableMatrix rk4(
493      DynamicsFunction f, VariableMatrix x, VariableMatrix u, Variable t0, Variable dt) {
494    var halfdt = dt.times(0.5);
495    VariableMatrix k1 = f.apply(t0, x, u, dt);
496    VariableMatrix k2 = f.apply(t0.plus(halfdt), x.plus(k1.times(halfdt)), u, dt);
497    VariableMatrix k3 = f.apply(t0.plus(halfdt), x.plus(k2.times(halfdt)), u, dt);
498    VariableMatrix k4 = f.apply(t0.plus(dt), x.plus(k3.times(dt)), u, dt);
499
500    return x.plus(k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4).times(dt.div(6.0)));
501  }
502
503  /** Applies direct collocation dynamics constraints. */
504  private void constrainDirectCollocation() {
505    assert m_dynamicsType == DynamicsType.EXPLICIT_ODE;
506
507    var time = new Variable(0.0);
508
509    // Derivation at https://mec560sbu.github.io/2016/09/30/direct_collocation/
510    for (int i = 0; i < m_numSteps; ++i) {
511      Variable h = dt().get(0, i);
512
513      var f = m_dynamics;
514
515      var t_begin = time;
516      var t_end = t_begin.plus(h);
517
518      var x_begin = X().col(i);
519      var x_end = X().col(i + 1);
520
521      var u_begin = U().col(i);
522      var u_end = U().col(i + 1);
523
524      var xdot_begin =
525          f.apply(t_begin, new VariableMatrix(x_begin), new VariableMatrix(u_begin), h);
526      var xdot_end = f.apply(t_end, new VariableMatrix(x_end), new VariableMatrix(u_end), h);
527      var xdot_c =
528          x_begin
529              .minus(x_end)
530              .times(new Variable(-3).div(h.times(2)))
531              .minus(xdot_begin.plus(xdot_end).times(0.25));
532
533      var t_c = t_begin.plus(h.times(0.5));
534      var x_c = x_begin.plus(x_end).times(0.5).plus(xdot_begin.minus(xdot_end).times(h.div(8)));
535      var u_c = u_begin.plus(u_end).times(0.5);
536
537      subjectTo(eq(xdot_c, f.apply(t_c, x_c, u_c, h)));
538
539      time = time.plus(h);
540    }
541  }
542
543  /** Applies direct transcription dynamics constraints. */
544  private void constrainDirectTranscription() {
545    var time = new Variable(0.0);
546
547    for (int i = 0; i < m_numSteps; ++i) {
548      var x_begin = X().col(i);
549      var x_end = X().col(i + 1);
550      var u = U().col(i);
551      Variable dt = this.dt().get(0, i);
552
553      if (m_dynamicsType == DynamicsType.EXPLICIT_ODE) {
554        subjectTo(
555            eq(
556                x_end,
557                rk4(m_dynamics, new VariableMatrix(x_begin), new VariableMatrix(u), time, dt)));
558      } else if (m_dynamicsType == DynamicsType.DISCRETE) {
559        subjectTo(
560            eq(
561                x_end,
562                m_dynamics.apply(time, new VariableMatrix(x_begin), new VariableMatrix(u), dt)));
563      }
564
565      time = time.plus(dt);
566    }
567  }
568
569  /** Applies single shooting dynamics constraints. */
570  private void constrainSingleShooting() {
571    var time = new Variable(0.0);
572
573    for (int i = 0; i < m_numSteps; ++i) {
574      var x_begin = X().col(i);
575      var x_end = X().col(i + 1);
576      var u = U().col(i);
577      Variable dt = this.dt().get(0, i);
578
579      if (m_dynamicsType == DynamicsType.EXPLICIT_ODE) {
580        x_end.set(rk4(m_dynamics, new VariableMatrix(x_begin), new VariableMatrix(u), time, dt));
581      } else if (m_dynamicsType == DynamicsType.DISCRETE) {
582        x_end.set(m_dynamics.apply(time, new VariableMatrix(x_begin), new VariableMatrix(u), dt));
583      }
584
585      time = time.plus(dt);
586    }
587  }
588}