001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.optimization;
006
007import java.util.function.Function;
008import java.util.function.ToDoubleFunction;
009
/**
 * An implementation of the Simulated Annealing stochastic nonlinear optimization method.
 *
 * <p>Each iteration proposes a neighbor of the current state. Proposals that lower the cost are
 * always accepted; proposals that raise the cost are accepted with probability {@code
 * exp(-deltaCost / temperature)}. On the k-th iteration (k = 1, 2, ...) the temperature is {@code
 * initialTemperature / k}, so worse states become less likely to be accepted over time. The
 * lowest-cost state encountered, including the initial guess, is returned.
 *
 * @see <a
 *     href="https://en.wikipedia.org/wiki/Simulated_annealing">https://en.wikipedia.org/wiki/Simulated_annealing</a>
 * @param <State> The type of the state to optimize.
 */
public final class SimulatedAnnealing<State> {
  private final double m_initialTemperature;
  private final Function<State, State> m_neighbor;
  private final ToDoubleFunction<State> m_cost;

  /**
   * Constructs a Simulated Annealing solver. The same solver can be reused with different initial
   * states.
   *
   * @param initialTemperature The initial temperature. Higher temperatures make it more likely a
   *     worse state will be accepted during iteration, helping to avoid local minima. The
   *     temperature is decreased over time.
   * @param neighbor Function that generates a random neighbor of the current state.
   * @param cost Function that returns the scalar cost of a state.
   */
  public SimulatedAnnealing(
      double initialTemperature, Function<State, State> neighbor, ToDoubleFunction<State> cost) {
    m_initialTemperature = initialTemperature;
    m_neighbor = neighbor;
    m_cost = cost;
  }

  /**
   * Runs the Simulated Annealing algorithm.
   *
   * @param initialGuess The initial state.
   * @param iterations Number of iterations to run the solver. Values less than or equal to zero
   *     perform no iterations and return {@code initialGuess}.
   * @return The lowest-cost state encountered (the initial guess or one of the proposals). A
   *     proposal only replaces the initial guess if its cost is strictly lower.
   */
  public State solve(State initialGuess, int iterations) {
    State state = initialGuess;
    double cost = m_cost.applyAsDouble(state);

    // Seed the best-so-far with the initial guess so the result is never worse than it. Seeding
    // with Double.MAX_VALUE would instead return the first proposal even when it is worse.
    State minState = state;
    double minCost = cost;

    for (int i = 0; i < iterations; ++i) {
      // Count iterations from 1 so the first one runs at initialTemperature instead of dividing
      // by zero (an infinite temperature that accepts any proposal).
      double temperature = m_initialTemperature / (i + 1);

      State proposedState = m_neighbor.apply(state);
      double proposedCost = m_cost.applyAsDouble(proposedState);
      double deltaCost = proposedCost - cost;

      // Always accept improvements; accept worse states with probability
      // exp(-deltaCost / temperature), which shrinks as the temperature cools.
      if (deltaCost < 0 || Math.exp(-deltaCost / temperature) >= Math.random()) {
        state = proposedState;
        cost = proposedCost;
      }

      // Track the best state seen, whether or not it was accepted as the current state.
      if (proposedCost < minCost) {
        minState = proposedState;
        minCost = proposedCost;
      }
    }

    return minState;
  }
}