001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package org.wpilib.math.optimization;
006
007import java.util.ArrayList;
008import java.util.function.Predicate;
009import org.ejml.data.DMatrixRMaj;
010import org.ejml.simple.SimpleMatrix;
011import org.wpilib.math.autodiff.ExpressionType;
012import org.wpilib.math.autodiff.NativeSparseTriplets;
013import org.wpilib.math.autodiff.Variable;
014import org.wpilib.math.autodiff.VariableMatrix;
015import org.wpilib.math.autodiff.VariablePool;
016import org.wpilib.math.optimization.solver.ExitStatus;
017import org.wpilib.math.optimization.solver.IterationInfo;
018import org.wpilib.math.optimization.solver.Options;
019
020/**
021 * This class allows the user to pose a constrained nonlinear optimization problem in natural
022 * mathematical notation and solve it.
023 *
024 * <p>This class supports problems of the form:
025 *
026 * <pre>
027 *       minₓ f(x)
028 * subject to cₑ(x) = 0
029 *            cᵢ(x) ≥ 0
030 * </pre>
031 *
032 * <p>where f(x) is the scalar cost function, x is the vector of decision variables (variables the
033 * solver can tweak to minimize the cost function), cᵢ(x) are the inequality constraints, and cₑ(x)
034 * are the equality constraints. Constraints are equations or inequalities of the decision variables
035 * that constrain what values the solver is allowed to use when searching for an optimal solution.
036 *
037 * <p>The nice thing about this class is users don't have to put their system in the form shown
038 * above manually; they can write it in natural mathematical form and it'll be converted for them.
039 */
040public class Problem implements AutoCloseable {
041  private long m_handle;
042
043  // The iteration callbacks
044  private final ArrayList<Predicate<IterationInfo>> m_iterationCallbacks = new ArrayList<>();
045
046  // Cleans up Variables allocated within Problem's scope
047  private final VariablePool m_pool = new VariablePool();
048
049  /** Construct the optimization problem. */
050  @SuppressWarnings("this-escape")
051  public Problem() {
052    m_handle = ProblemJNI.create();
053  }
054
055  @Override
056  public void close() {
057    if (m_handle != 0) {
058      ProblemJNI.destroy(m_handle);
059      m_handle = 0;
060
061      m_pool.close();
062    }
063  }
064
065  /**
066   * Creates a decision variable in the optimization problem.
067   *
068   * <p>Decision variables have an initial value of zero.
069   *
070   * @return A decision variable in the optimization problem.
071   */
072  public Variable decisionVariable() {
073    var handles = ProblemJNI.decisionVariable(m_handle, 1, 1);
074    return new Variable(Variable.HANDLE, handles[0]);
075  }
076
077  /**
078   * Creates a column vector of decision variables in the optimization problem.
079   *
080   * <p>Decision variables have an initial value of zero.
081   *
082   * @param rows Number of column vector rows.
083   * @return A column vector of decision variables in the optimization problem.
084   */
085  public VariableMatrix decisionVariable(int rows) {
086    return decisionVariable(rows, 1);
087  }
088
089  /**
090   * Creates a matrix of decision variables in the optimization problem.
091   *
092   * <p>Decision variables have an initial value of zero.
093   *
094   * @param rows Number of matrix rows.
095   * @param cols Number of matrix columns.
096   * @return A matrix of decision variables in the optimization problem.
097   */
098  public VariableMatrix decisionVariable(int rows, int cols) {
099    return new VariableMatrix(rows, cols, ProblemJNI.decisionVariable(m_handle, rows, cols));
100  }
101
102  /**
103   * Creates a symmetric matrix of decision variables in the optimization problem.
104   *
105   * <p>Variable instances are reused across the diagonal, which helps reduce problem
106   * dimensionality.
107   *
108   * <p>Decision variables have an initial value of zero.
109   *
110   * @param rows Number of matrix rows.
111   * @return A symmetric matrix of decision varaibles in the optimization problem.
112   */
113  public VariableMatrix symmetricDecisionVariable(int rows) {
114    return new VariableMatrix(rows, rows, ProblemJNI.symmetricDecisionVariable(m_handle, rows));
115  }
116
117  /**
118   * Tells the solver to minimize the output of the given cost function.
119   *
120   * <p>Note that this is optional. If only constraints are specified, the solver will find the
121   * closest solution to the initial conditions that's in the feasible set.
122   *
123   * @param cost The cost function to minimize.
124   */
125  public void minimize(Variable cost) {
126    ProblemJNI.minimize(m_handle, cost.getHandle());
127  }
128
129  /**
130   * Tells the solver to minimize the output of the given cost function.
131   *
132   * <p>Note that this is optional. If only constraints are specified, the solver will find the
133   * closest solution to the initial conditions that's in the feasible set.
134   *
135   * @param cost The cost function to minimize. An assertion is raised if the VariableMatrix isn't
136   *     1x1.
137   */
138  public void minimize(VariableMatrix cost) {
139    assert cost.rows() == 1 && cost.cols() == 1;
140    minimize(cost.get(0, 0));
141  }
142
143  /**
144   * Tells the solver to maximize the output of the given objective function.
145   *
146   * <p>Note that this is optional. If only constraints are specified, the solver will find the
147   * closest solution to the initial conditions that's in the feasible set.
148   *
149   * @param objective The objective function to maximize.
150   */
151  public void maximize(Variable objective) {
152    ProblemJNI.maximize(m_handle, objective.getHandle());
153  }
154
155  /**
156   * Tells the solver to maximize the output of the given objective function.
157   *
158   * <p>Note that this is optional. If only constraints are specified, the solver will find the
159   * closest solution to the initial conditions that's in the feasible set.
160   *
161   * @param objective The objective function to maximize. An assertion is raised if the
162   *     VariableMatrix isn't 1x1.
163   */
164  public void maximize(VariableMatrix objective) {
165    assert objective.rows() == 1 && objective.cols() == 1;
166    maximize(objective.get(0, 0));
167  }
168
169  /**
170   * Tells the solver to solve the problem while satisfying the given equality constraint.
171   *
172   * @param constraint The constraint to satisfy.
173   */
174  public void subjectTo(EqualityConstraints constraint) {
175    var constraintHandles = new long[constraint.constraints.length];
176    for (int i = 0; i < constraintHandles.length; ++i) {
177      constraintHandles[i] = constraint.constraints[i].getHandle();
178    }
179    ProblemJNI.subjectToEq(m_handle, constraintHandles);
180  }
181
182  /**
183   * Tells the solver to solve the problem while satisfying the given inequality constraint.
184   *
185   * @param constraint The constraint to satisfy.
186   */
187  public void subjectTo(InequalityConstraints constraint) {
188    var constraintHandles = new long[constraint.constraints.length];
189    for (int i = 0; i < constraintHandles.length; ++i) {
190      constraintHandles[i] = constraint.constraints[i].getHandle();
191    }
192    ProblemJNI.subjectToIneq(m_handle, constraintHandles);
193  }
194
195  /**
196   * Returns the cost function's type.
197   *
198   * @return The cost function's type.
199   */
200  public ExpressionType costFunctionType() {
201    return ExpressionType.fromInt(ProblemJNI.costFunctionType(m_handle));
202  }
203
204  /**
205   * Returns the type of the highest order equality constraint.
206   *
207   * @return The type of the highest order equality constraint.
208   */
209  public ExpressionType equalityConstraintType() {
210    return ExpressionType.fromInt(ProblemJNI.equalityConstraintType(m_handle));
211  }
212
213  /**
214   * Returns the type of the highest order inequality constraint.
215   *
216   * @return The type of the highest order inequality constraint.
217   */
218  public ExpressionType inequalityConstraintType() {
219    return ExpressionType.fromInt(ProblemJNI.inequalityConstraintType(m_handle));
220  }
221
222  /**
223   * Solves the optimization problem. The solution will be stored in the original variables used to
224   * construct the problem.
225   *
226   * @return The solver status.
227   */
228  public ExitStatus solve() {
229    return solve(new Options());
230  }
231
232  /**
233   * Solves the optimization problem. The solution will be stored in the original variables used to
234   * construct the problem.
235   *
236   * @param options Solver options.
237   * @return The solver status.
238   */
239  public ExitStatus solve(Options options) {
240    return ExitStatus.fromInt(
241        ProblemJNI.solve(
242            this,
243            m_handle,
244            options.tolerance,
245            options.maxIterations,
246            options.timeout,
247            options.feasibleIPM,
248            options.diagnostics));
249  }
250
251  /**
252   * Adds a callback to be called at the beginning of each solver iteration.
253   *
254   * <p>The callback for this overload should return bool.
255   *
256   * @param callback The callback. Returning true from the callback causes the solver to exit early
257   *     with the solution it has so far.
258   */
259  public void addCallback(Predicate<IterationInfo> callback) {
260    m_iterationCallbacks.add(callback);
261  }
262
263  /** Clears the registered callbacks. */
264  public void clearCallbacks() {
265    m_iterationCallbacks.clear();
266  }
267
268  /**
269   * Runs the registered callbacks.
270   *
271   * <p>This function is called by native code in ProblemJNI.
272   *
273   * @param numEqualityConstraints The number of equality constraints.
274   * @param numInequalityConstraints The number of inequality constraints.
275   * @param iteration The solver iteration.
276   * @param x The decision variable values.
277   * @param gTriplets Gradient triplets.
278   * @param HTriplets Hessian triplets.
279   * @param A_eTriplets Equality constraint Jacobian triplets.
280   * @param A_iTriplets Inequality constraint Jacobian triplets.
281   * @return True if the solver shold exit early.
282   */
283  boolean runCallbacks(
284      int numEqualityConstraints,
285      int numInequalityConstraints,
286      int iteration,
287      double[] x,
288      NativeSparseTriplets gTriplets,
289      NativeSparseTriplets HTriplets,
290      NativeSparseTriplets A_eTriplets,
291      NativeSparseTriplets A_iTriplets) {
292    if (m_iterationCallbacks.isEmpty()) {
293      return false;
294    }
295
296    var info =
297        new IterationInfo(
298            iteration,
299            new SimpleMatrix(DMatrixRMaj.wrap(x.length, 1, x)),
300            gTriplets.toSimpleMatrix(x.length, 1),
301            HTriplets.toSimpleMatrix(x.length, x.length),
302            A_eTriplets.toSimpleMatrix(numEqualityConstraints, x.length),
303            A_iTriplets.toSimpleMatrix(numInequalityConstraints, x.length));
304
305    for (var callback : m_iterationCallbacks) {
306      if (callback.test(info)) {
307        return true;
308      }
309    }
310    return false;
311  }
312}