001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.path;
006
007import edu.wpi.first.math.Num;
008import edu.wpi.first.math.Vector;
009import edu.wpi.first.math.geometry.Pose2d;
010import edu.wpi.first.math.optimization.SimulatedAnnealing;
011import java.util.Arrays;
012import java.util.Collections;
013import java.util.function.ToDoubleBiFunction;
014
015/**
016 * Given a list of poses, this class finds the shortest possible route that visits each pose exactly
017 * once and returns to the origin pose.
018 *
019 * @see <a
020 *     href="https://en.wikipedia.org/wiki/Travelling_salesman_problem">https://en.wikipedia.org/wiki/Travelling_salesman_problem</a>
021 */
022public class TravelingSalesman {
023  // Default cost is 2D distance between poses
024  private final ToDoubleBiFunction<Pose2d, Pose2d> m_cost;
025
026  /**
027   * Constructs a traveling salesman problem solver with a cost function defined as the 2D distance
028   * between poses.
029   */
030  public TravelingSalesman() {
031    this((Pose2d a, Pose2d b) -> Math.hypot(a.getX() - b.getX(), a.getY() - b.getY()));
032  }
033
034  /**
035   * Constructs a traveling salesman problem solver with a user-provided cost function.
036   *
037   * @param cost Function that returns the cost between two poses. The sum of the costs for every
038   *     pair of poses is minimized.
039   */
040  public TravelingSalesman(ToDoubleBiFunction<Pose2d, Pose2d> cost) {
041    m_cost = cost;
042  }
043
044  /**
045   * Finds the path through every pose that minimizes the cost. The first pose in the returned array
046   * is the first pose that was passed in.
047   *
048   * @param <Poses> A Num defining the length of the path and the number of poses.
049   * @param poses An array of Pose2ds the path must pass through.
050   * @param iterations The number of times the solver attempts to find a better random neighbor.
051   * @return The optimized path as an array of Pose2ds.
052   */
053  public <Poses extends Num> Pose2d[] solve(Pose2d[] poses, int iterations) {
054    var solver =
055        new SimulatedAnnealing<>(
056            1.0,
057            this::neighbor,
058            // Total cost is sum of all costs between adjacent pose pairs in path
059            (Vector<Poses> state) -> {
060              double sum = 0.0;
061              for (int i = 0; i < state.getNumRows(); ++i) {
062                sum +=
063                    m_cost.applyAsDouble(
064                        poses[(int) state.get(i, 0)],
065                        poses[(int) state.get((i + 1) % poses.length, 0)]);
066              }
067              return sum;
068            });
069
070    var initial = new Vector<Poses>(() -> poses.length);
071    for (int i = 0; i < poses.length; ++i) {
072      initial.set(i, 0, i);
073    }
074
075    var indices = solver.solve(initial, iterations);
076
077    var solution = new Pose2d[poses.length];
078    for (int i = 0; i < poses.length; ++i) {
079      solution[i] = poses[(int) indices.get(i, 0)];
080    }
081
082    // Rotate solution list until solution[0] = poses[0]
083    Collections.rotate(Arrays.asList(solution), -Arrays.asList(solution).indexOf(poses[0]));
084
085    return solution;
086  }
087
088  /**
089   * A random neighbor is generated to try to replace the current one.
090   *
091   * @param state A vector that is a list of indices that defines the path through the path array.
092   * @return Generates a random neighbor of the current state by flipping a random range in the path
093   *     array.
094   */
095  private <Poses extends Num> Vector<Poses> neighbor(Vector<Poses> state) {
096    var proposedState = new Vector<>(state);
097
098    int rangeStart = (int) (Math.random() * (state.getNumRows() - 1));
099    int rangeEnd = (int) (Math.random() * (state.getNumRows() - 1));
100    if (rangeEnd < rangeStart) {
101      int temp = rangeEnd;
102      rangeEnd = rangeStart;
103      rangeStart = temp;
104    }
105
106    for (int i = rangeStart; i <= (rangeStart + rangeEnd) / 2; ++i) {
107      double temp = proposedState.get(i, 0);
108      proposedState.set(i, 0, state.get(rangeEnd - (i - rangeStart), 0));
109      proposedState.set(rangeEnd - (i - rangeStart), 0, temp);
110    }
111
112    return proposedState;
113  }
114}