001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.spline;
006
007import edu.wpi.first.math.spline.proto.CubicHermiteSplineProto;
008import edu.wpi.first.math.spline.struct.CubicHermiteSplineStruct;
009import edu.wpi.first.util.protobuf.ProtobufSerializable;
010import edu.wpi.first.util.struct.StructSerializable;
011import org.ejml.simple.SimpleMatrix;
012
013/** Represents a hermite spline of degree 3. */
014public class CubicHermiteSpline extends Spline implements ProtobufSerializable, StructSerializable {
015  private static SimpleMatrix hermiteBasis;
016  private final SimpleMatrix m_coefficients;
017
018  /** The control vector for the initial point in the x dimension. DO NOT MODIFY THIS ARRAY! */
019  public final double[] xInitialControlVector;
020
021  /** The control vector for the final point in the x dimension. DO NOT MODIFY THIS ARRAY! */
022  public final double[] xFinalControlVector;
023
024  /** The control vector for the initial point in the y dimension. DO NOT MODIFY THIS ARRAY! */
025  public final double[] yInitialControlVector;
026
027  /** The control vector for the final point in the y dimension. DO NOT MODIFY THIS ARRAY! */
028  public final double[] yFinalControlVector;
029
030  private final ControlVector m_initialControlVector;
031  private final ControlVector m_finalControlVector;
032
033  /**
034   * Constructs a cubic hermite spline with the specified control vectors. Each control vector
035   * contains info about the location of the point and its first derivative.
036   *
037   * @param xInitialControlVector The control vector for the initial point in the x dimension.
038   * @param xFinalControlVector The control vector for the final point in the x dimension.
039   * @param yInitialControlVector The control vector for the initial point in the y dimension.
040   * @param yFinalControlVector The control vector for the final point in the y dimension.
041   */
042  @SuppressWarnings("PMD.ArrayIsStoredDirectly")
043  public CubicHermiteSpline(
044      double[] xInitialControlVector,
045      double[] xFinalControlVector,
046      double[] yInitialControlVector,
047      double[] yFinalControlVector) {
048    super(3);
049    this.xInitialControlVector = xInitialControlVector;
050    this.xFinalControlVector = xFinalControlVector;
051    this.yInitialControlVector = yInitialControlVector;
052    this.yFinalControlVector = yFinalControlVector;
053
054    // Populate the coefficients for the actual spline equations.
055    // Row 0 is x coefficients
056    // Row 1 is y coefficients
057    final var hermite = makeHermiteBasis();
058    final var x = getControlVectorFromArrays(xInitialControlVector, xFinalControlVector);
059    final var y = getControlVectorFromArrays(yInitialControlVector, yFinalControlVector);
060
061    final var xCoeffs = (hermite.mult(x)).transpose();
062    final var yCoeffs = (hermite.mult(y)).transpose();
063
064    m_coefficients = new SimpleMatrix(6, 4);
065
066    for (int i = 0; i < 4; i++) {
067      m_coefficients.set(0, i, xCoeffs.get(0, i));
068      m_coefficients.set(1, i, yCoeffs.get(0, i));
069
070      // Populate Row 2 and Row 3 with the derivatives of the equations above.
071      // Then populate row 4 and 5 with the second derivatives.
072      // Here, we are multiplying by (3 - i) to manually take the derivative. The
073      // power of the term in index 0 is 3, index 1 is 2 and so on. To find the
074      // coefficient of the derivative, we can use the power rule and multiply
075      // the existing coefficient by its power.
076      m_coefficients.set(2, i, m_coefficients.get(0, i) * (3 - i));
077      m_coefficients.set(3, i, m_coefficients.get(1, i) * (3 - i));
078    }
079
080    for (int i = 0; i < 3; i++) {
081      // Here, we are multiplying by (2 - i) to manually take the derivative. The
082      // power of the term in index 0 is 2, index 1 is 1 and so on. To find the
083      // coefficient of the derivative, we can use the power rule and multiply
084      // the existing coefficient by its power.
085      m_coefficients.set(4, i, m_coefficients.get(2, i) * (2 - i));
086      m_coefficients.set(5, i, m_coefficients.get(3, i) * (2 - i));
087    }
088
089    // Assign member variables.
090    m_initialControlVector = new ControlVector(xInitialControlVector, yInitialControlVector);
091    m_finalControlVector = new ControlVector(xFinalControlVector, yFinalControlVector);
092  }
093
094  /**
095   * Returns the coefficients matrix.
096   *
097   * @return The coefficients matrix.
098   */
099  @Override
100  public SimpleMatrix getCoefficients() {
101    return m_coefficients;
102  }
103
104  /**
105   * Returns the initial control vector that created this spline.
106   *
107   * @return The initial control vector that created this spline.
108   */
109  @Override
110  public ControlVector getInitialControlVector() {
111    return m_initialControlVector;
112  }
113
114  /**
115   * Returns the final control vector that created this spline.
116   *
117   * @return The final control vector that created this spline.
118   */
119  @Override
120  public ControlVector getFinalControlVector() {
121    return m_finalControlVector;
122  }
123
124  /**
125   * Returns the hermite basis matrix for cubic hermite spline interpolation.
126   *
127   * @return The hermite basis matrix for cubic hermite spline interpolation.
128   */
129  @SuppressWarnings("PMD.UnnecessaryVarargsArrayCreation")
130  private SimpleMatrix makeHermiteBasis() {
131    if (hermiteBasis == null) {
132      // Given P(i), P'(i), P(i+1), P'(i+1), the control vectors, we want to find
133      // the coefficients of the spline P(t) = a₃t³ + a₂t² + a₁t + a₀.
134      //
135      // P(i)    = P(0)  = a₀
136      // P'(i)   = P'(0) = a₁
137      // P(i+1)  = P(1)  = a₃ + a₂ + a₁ + a₀
138      // P'(i+1) = P'(1) = 3a₃ + 2a₂ + a₁
139      //
140      // [P(i)   ] = [0 0 0 1][a₃]
141      // [P'(i)  ] = [0 0 1 0][a₂]
142      // [P(i+1) ] = [1 1 1 1][a₁]
143      // [P'(i+1)] = [3 2 1 0][a₀]
144      //
145      // To solve for the coefficients, we can invert the 4x4 matrix and move it
146      // to the other side of the equation.
147      //
148      // [a₃] = [ 2  1 -2  1][P(i)   ]
149      // [a₂] = [-3 -2  3 -1][P'(i)  ]
150      // [a₁] = [ 0  1  0  0][P(i+1) ]
151      // [a₀] = [ 1  0  0  0][P'(i+1)]
152      hermiteBasis =
153          new SimpleMatrix(
154              4,
155              4,
156              true,
157              new double[] {
158                +2.0, +1.0, -2.0, +1.0, -3.0, -2.0, +3.0, -1.0, +0.0, +1.0, +0.0, +0.0, +1.0, +0.0,
159                +0.0, +0.0
160              });
161    }
162    return hermiteBasis;
163  }
164
165  /**
166   * Returns the control vector for each dimension as a matrix from the user-provided arrays in the
167   * constructor.
168   *
169   * @param initialVector The control vector for the initial point.
170   * @param finalVector The control vector for the final point.
171   * @return The control vector matrix for a dimension.
172   */
173  @SuppressWarnings("PMD.UnnecessaryVarargsArrayCreation")
174  private SimpleMatrix getControlVectorFromArrays(double[] initialVector, double[] finalVector) {
175    if (initialVector.length < 2 || finalVector.length < 2) {
176      throw new IllegalArgumentException("Size of vectors must be 2 or greater.");
177    }
178    return new SimpleMatrix(
179        4,
180        1,
181        true,
182        new double[] {
183          initialVector[0], initialVector[1],
184          finalVector[0], finalVector[1]
185        });
186  }
187
188  /** CubicHermiteSpline struct for serialization. */
189  public static final CubicHermiteSplineProto proto = new CubicHermiteSplineProto();
190
191  /** CubicHermiteSpline protobuf for serialization. */
192  public static final CubicHermiteSplineStruct struct = new CubicHermiteSplineStruct();
193}