001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package org.wpilib.math.optimization; 006 007import java.util.ArrayList; 008import java.util.function.Predicate; 009import org.ejml.data.DMatrixRMaj; 010import org.ejml.simple.SimpleMatrix; 011import org.wpilib.math.autodiff.ExpressionType; 012import org.wpilib.math.autodiff.NativeSparseTriplets; 013import org.wpilib.math.autodiff.Variable; 014import org.wpilib.math.autodiff.VariableMatrix; 015import org.wpilib.math.autodiff.VariablePool; 016import org.wpilib.math.optimization.solver.ExitStatus; 017import org.wpilib.math.optimization.solver.IterationInfo; 018import org.wpilib.math.optimization.solver.Options; 019 020/** 021 * This class allows the user to pose a constrained nonlinear optimization problem in natural 022 * mathematical notation and solve it. 023 * 024 * <p>This class supports problems of the form: 025 * 026 * <pre> 027 * minₓ f(x) 028 * subject to cₑ(x) = 0 029 * cᵢ(x) ≥ 0 030 * </pre> 031 * 032 * <p>where f(x) is the scalar cost function, x is the vector of decision variables (variables the 033 * solver can tweak to minimize the cost function), cᵢ(x) are the inequality constraints, and cₑ(x) 034 * are the equality constraints. Constraints are equations or inequalities of the decision variables 035 * that constrain what values the solver is allowed to use when searching for an optimal solution. 036 * 037 * <p>The nice thing about this class is users don't have to put their system in the form shown 038 * above manually; they can write it in natural mathematical form and it'll be converted for them. 039 */ 040public class Problem implements AutoCloseable { 041 private long m_handle; 042 043 // The iteration callbacks 044 private final ArrayList<Predicate<IterationInfo>> m_iterationCallbacks = new ArrayList<>(); 045 046 // Cleans up Variables allocated within Problem's scope 047 private final VariablePool m_pool = new VariablePool(); 048 049 /** Construct the optimization problem. */ 050 @SuppressWarnings("this-escape") 051 public Problem() { 052 m_handle = ProblemJNI.create(); 053 } 054 055 @Override 056 public void close() { 057 if (m_handle != 0) { 058 ProblemJNI.destroy(m_handle); 059 m_handle = 0; 060 061 m_pool.close(); 062 } 063 } 064 065 /** 066 * Creates a decision variable in the optimization problem. 067 * 068 * <p>Decision variables have an initial value of zero. 069 * 070 * @return A decision variable in the optimization problem. 071 */ 072 public Variable decisionVariable() { 073 var handles = ProblemJNI.decisionVariable(m_handle, 1, 1); 074 return new Variable(Variable.HANDLE, handles[0]); 075 } 076 077 /** 078 * Creates a column vector of decision variables in the optimization problem. 079 * 080 * <p>Decision variables have an initial value of zero. 081 * 082 * @param rows Number of column vector rows. 083 * @return A column vector of decision variables in the optimization problem. 084 */ 085 public VariableMatrix decisionVariable(int rows) { 086 return decisionVariable(rows, 1); 087 } 088 089 /** 090 * Creates a matrix of decision variables in the optimization problem. 091 * 092 * <p>Decision variables have an initial value of zero. 093 * 094 * @param rows Number of matrix rows. 095 * @param cols Number of matrix columns. 096 * @return A matrix of decision variables in the optimization problem. 097 */ 098 public VariableMatrix decisionVariable(int rows, int cols) { 099 return new VariableMatrix(rows, cols, ProblemJNI.decisionVariable(m_handle, rows, cols)); 100 } 101 102 /** 103 * Creates a symmetric matrix of decision variables in the optimization problem. 104 * 105 * <p>Variable instances are reused across the diagonal, which helps reduce problem 106 * dimensionality. 107 * 108 * <p>Decision variables have an initial value of zero. 109 * 110 * @param rows Number of matrix rows. 111 * @return A symmetric matrix of decision varaibles in the optimization problem. 112 */ 113 public VariableMatrix symmetricDecisionVariable(int rows) { 114 return new VariableMatrix(rows, rows, ProblemJNI.symmetricDecisionVariable(m_handle, rows)); 115 } 116 117 /** 118 * Tells the solver to minimize the output of the given cost function. 119 * 120 * <p>Note that this is optional. If only constraints are specified, the solver will find the 121 * closest solution to the initial conditions that's in the feasible set. 122 * 123 * @param cost The cost function to minimize. 124 */ 125 public void minimize(Variable cost) { 126 ProblemJNI.minimize(m_handle, cost.getHandle()); 127 } 128 129 /** 130 * Tells the solver to minimize the output of the given cost function. 131 * 132 * <p>Note that this is optional. If only constraints are specified, the solver will find the 133 * closest solution to the initial conditions that's in the feasible set. 134 * 135 * @param cost The cost function to minimize. An assertion is raised if the VariableMatrix isn't 136 * 1x1. 137 */ 138 public void minimize(VariableMatrix cost) { 139 assert cost.rows() == 1 && cost.cols() == 1; 140 minimize(cost.get(0, 0)); 141 } 142 143 /** 144 * Tells the solver to maximize the output of the given objective function. 145 * 146 * <p>Note that this is optional. If only constraints are specified, the solver will find the 147 * closest solution to the initial conditions that's in the feasible set. 148 * 149 * @param objective The objective function to maximize. 150 */ 151 public void maximize(Variable objective) { 152 ProblemJNI.maximize(m_handle, objective.getHandle()); 153 } 154 155 /** 156 * Tells the solver to maximize the output of the given objective function. 157 * 158 * <p>Note that this is optional. If only constraints are specified, the solver will find the 159 * closest solution to the initial conditions that's in the feasible set. 160 * 161 * @param objective The objective function to maximize. An assertion is raised if the 162 * VariableMatrix isn't 1x1. 163 */ 164 public void maximize(VariableMatrix objective) { 165 assert objective.rows() == 1 && objective.cols() == 1; 166 maximize(objective.get(0, 0)); 167 } 168 169 /** 170 * Tells the solver to solve the problem while satisfying the given equality constraint. 171 * 172 * @param constraint The constraint to satisfy. 173 */ 174 public void subjectTo(EqualityConstraints constraint) { 175 var constraintHandles = new long[constraint.constraints.length]; 176 for (int i = 0; i < constraintHandles.length; ++i) { 177 constraintHandles[i] = constraint.constraints[i].getHandle(); 178 } 179 ProblemJNI.subjectToEq(m_handle, constraintHandles); 180 } 181 182 /** 183 * Tells the solver to solve the problem while satisfying the given inequality constraint. 184 * 185 * @param constraint The constraint to satisfy. 186 */ 187 public void subjectTo(InequalityConstraints constraint) { 188 var constraintHandles = new long[constraint.constraints.length]; 189 for (int i = 0; i < constraintHandles.length; ++i) { 190 constraintHandles[i] = constraint.constraints[i].getHandle(); 191 } 192 ProblemJNI.subjectToIneq(m_handle, constraintHandles); 193 } 194 195 /** 196 * Returns the cost function's type. 197 * 198 * @return The cost function's type. 199 */ 200 public ExpressionType costFunctionType() { 201 return ExpressionType.fromInt(ProblemJNI.costFunctionType(m_handle)); 202 } 203 204 /** 205 * Returns the type of the highest order equality constraint. 206 * 207 * @return The type of the highest order equality constraint. 208 */ 209 public ExpressionType equalityConstraintType() { 210 return ExpressionType.fromInt(ProblemJNI.equalityConstraintType(m_handle)); 211 } 212 213 /** 214 * Returns the type of the highest order inequality constraint. 215 * 216 * @return The type of the highest order inequality constraint. 217 */ 218 public ExpressionType inequalityConstraintType() { 219 return ExpressionType.fromInt(ProblemJNI.inequalityConstraintType(m_handle)); 220 } 221 222 /** 223 * Solves the optimization problem. The solution will be stored in the original variables used to 224 * construct the problem. 225 * 226 * @return The solver status. 227 */ 228 public ExitStatus solve() { 229 return solve(new Options()); 230 } 231 232 /** 233 * Solves the optimization problem. The solution will be stored in the original variables used to 234 * construct the problem. 235 * 236 * @param options Solver options. 237 * @return The solver status. 238 */ 239 public ExitStatus solve(Options options) { 240 return ExitStatus.fromInt( 241 ProblemJNI.solve( 242 this, 243 m_handle, 244 options.tolerance, 245 options.maxIterations, 246 options.timeout, 247 options.feasibleIPM, 248 options.diagnostics)); 249 } 250 251 /** 252 * Adds a callback to be called at the beginning of each solver iteration. 253 * 254 * <p>The callback for this overload should return bool. 255 * 256 * @param callback The callback. Returning true from the callback causes the solver to exit early 257 * with the solution it has so far. 258 */ 259 public void addCallback(Predicate<IterationInfo> callback) { 260 m_iterationCallbacks.add(callback); 261 } 262 263 /** Clears the registered callbacks. */ 264 public void clearCallbacks() { 265 m_iterationCallbacks.clear(); 266 } 267 268 /** 269 * Runs the registered callbacks. 270 * 271 * <p>This function is called by native code in ProblemJNI. 272 * 273 * @param numEqualityConstraints The number of equality constraints. 274 * @param numInequalityConstraints The number of inequality constraints. 275 * @param iteration The solver iteration. 276 * @param x The decision variable values. 277 * @param gTriplets Gradient triplets. 278 * @param HTriplets Hessian triplets. 279 * @param A_eTriplets Equality constraint Jacobian triplets. 280 * @param A_iTriplets Inequality constraint Jacobian triplets. 281 * @return True if the solver shold exit early. 282 */ 283 boolean runCallbacks( 284 int numEqualityConstraints, 285 int numInequalityConstraints, 286 int iteration, 287 double[] x, 288 NativeSparseTriplets gTriplets, 289 NativeSparseTriplets HTriplets, 290 NativeSparseTriplets A_eTriplets, 291 NativeSparseTriplets A_iTriplets) { 292 if (m_iterationCallbacks.isEmpty()) { 293 return false; 294 } 295 296 var info = 297 new IterationInfo( 298 iteration, 299 new SimpleMatrix(DMatrixRMaj.wrap(x.length, 1, x)), 300 gTriplets.toSimpleMatrix(x.length, 1), 301 HTriplets.toSimpleMatrix(x.length, x.length), 302 A_eTriplets.toSimpleMatrix(numEqualityConstraints, x.length), 303 A_iTriplets.toSimpleMatrix(numInequalityConstraints, x.length)); 304 305 for (var callback : m_iterationCallbacks) { 306 if (callback.test(info)) { 307 return true; 308 } 309 } 310 return false; 311 } 312}