001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package org.wpilib.commands3;
006
import static edu.wpi.first.util.ErrorMessages.requireNonNullParam;

import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Set;
013
014/**
015 * A type of command that runs multiple other commands in parallel. The group will end after all
016 * required commands have completed; if no command is explicitly required for completion, then the
017 * group will end after any command in the group finishes. Any commands still running when the group
018 * has reached its completion condition will be canceled.
019 */
020public final class ParallelGroup implements Command {
021  private final Collection<Command> m_requiredCommands = new HashSet<>();
022  private final Collection<Command> m_optionalCommands = new LinkedHashSet<>();
023  private final Set<Mechanism> m_requirements = new HashSet<>();
024  private final String m_name;
025  private final int m_priority;
026
027  /**
028   * Creates a new parallel group.
029   *
030   * @param name The name of the group.
031   * @param requiredCommands The commands that are required to complete for the group to finish. If
032   *     this is empty, then the group will finish when <i>any</i> optional command completes.
033   * @param optionalCommands The commands that do not need to complete for the group to finish. If
034   *     this is empty, then the group will finish when <i>all</i> required commands complete.
035   */
036  ParallelGroup(
037      String name, Collection<Command> requiredCommands, Collection<Command> optionalCommands) {
038    requireNonNullParam(name, "name", "ParallelGroup");
039    requireNonNullParam(requiredCommands, "requiredCommands", "ParallelGroup");
040    requireNonNullParam(optionalCommands, "optionalCommands", "ParallelGroup");
041
042    int i = 0;
043    for (Command requiredCommand : requiredCommands) {
044      requireNonNullParam(requiredCommand, "requiredCommands[" + i + "]", "ParallelGroup");
045      i++;
046    }
047
048    i = 0;
049    for (Command c : optionalCommands) {
050      requireNonNullParam(c, "optionalCommands[" + i + "]", "ParallelGroup");
051      i++;
052    }
053
054    var allCommands = new LinkedHashSet<Command>();
055    allCommands.addAll(requiredCommands);
056    allCommands.addAll(optionalCommands);
057
058    ConflictDetector.throwIfConflicts(allCommands);
059
060    m_name = name;
061    m_optionalCommands.addAll(optionalCommands);
062    m_requiredCommands.addAll(requiredCommands);
063
064    for (var command : allCommands) {
065      m_requirements.addAll(command.requirements());
066    }
067
068    m_priority =
069        allCommands.stream().mapToInt(Command::priority).max().orElse(Command.DEFAULT_PRIORITY);
070  }
071
072  @Override
073  public void run(Coroutine coroutine) {
074    coroutine.fork(m_optionalCommands);
075
076    if (m_requiredCommands.isEmpty()) {
077      // No required commands - just wait for the first optional command to finish
078      coroutine.awaitAny(m_optionalCommands);
079    } else {
080      // Wait for every required command to finish
081      coroutine.awaitAll(m_requiredCommands);
082    }
083
084    // The scheduler will cancel any optional child commands that are still running when
085    // the `run` method returns
086  }
087
088  @Override
089  public String name() {
090    return m_name;
091  }
092
093  @Override
094  public Set<Mechanism> requirements() {
095    return m_requirements;
096  }
097
098  @Override
099  public int priority() {
100    return m_priority;
101  }
102
103  @Override
104  public String toString() {
105    return "ParallelGroup[name=" + m_name + "]";
106  }
107}