001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.spline;
006
007import org.ejml.simple.SimpleMatrix;
008
009/** Represents a hermite spline of degree 3. */
010public class CubicHermiteSpline extends Spline {
011  private static SimpleMatrix hermiteBasis;
012  private final SimpleMatrix m_coefficients;
013
014  private final ControlVector m_initialControlVector;
015  private final ControlVector m_finalControlVector;
016
017  /**
018   * Constructs a cubic hermite spline with the specified control vectors. Each control vector
019   * contains info about the location of the point and its first derivative.
020   *
021   * @param xInitialControlVector The control vector for the initial point in the x dimension.
022   * @param xFinalControlVector The control vector for the final point in the x dimension.
023   * @param yInitialControlVector The control vector for the initial point in the y dimension.
024   * @param yFinalControlVector The control vector for the final point in the y dimension.
025   */
026  public CubicHermiteSpline(
027      double[] xInitialControlVector,
028      double[] xFinalControlVector,
029      double[] yInitialControlVector,
030      double[] yFinalControlVector) {
031    super(3);
032
033    // Populate the coefficients for the actual spline equations.
034    // Row 0 is x coefficients
035    // Row 1 is y coefficients
036    final var hermite = makeHermiteBasis();
037    final var x = getControlVectorFromArrays(xInitialControlVector, xFinalControlVector);
038    final var y = getControlVectorFromArrays(yInitialControlVector, yFinalControlVector);
039
040    final var xCoeffs = (hermite.mult(x)).transpose();
041    final var yCoeffs = (hermite.mult(y)).transpose();
042
043    m_coefficients = new SimpleMatrix(6, 4);
044
045    for (int i = 0; i < 4; i++) {
046      m_coefficients.set(0, i, xCoeffs.get(0, i));
047      m_coefficients.set(1, i, yCoeffs.get(0, i));
048
049      // Populate Row 2 and Row 3 with the derivatives of the equations above.
050      // Then populate row 4 and 5 with the second derivatives.
051      // Here, we are multiplying by (3 - i) to manually take the derivative. The
052      // power of the term in index 0 is 3, index 1 is 2 and so on. To find the
053      // coefficient of the derivative, we can use the power rule and multiply
054      // the existing coefficient by its power.
055      m_coefficients.set(2, i, m_coefficients.get(0, i) * (3 - i));
056      m_coefficients.set(3, i, m_coefficients.get(1, i) * (3 - i));
057    }
058
059    for (int i = 0; i < 3; i++) {
060      // Here, we are multiplying by (2 - i) to manually take the derivative. The
061      // power of the term in index 0 is 2, index 1 is 1 and so on. To find the
062      // coefficient of the derivative, we can use the power rule and multiply
063      // the existing coefficient by its power.
064      m_coefficients.set(4, i, m_coefficients.get(2, i) * (2 - i));
065      m_coefficients.set(5, i, m_coefficients.get(3, i) * (2 - i));
066    }
067
068    // Assign member variables.
069    m_initialControlVector = new ControlVector(xInitialControlVector, yInitialControlVector);
070    m_finalControlVector = new ControlVector(xFinalControlVector, yFinalControlVector);
071  }
072
073  /**
074   * Returns the coefficients matrix.
075   *
076   * @return The coefficients matrix.
077   */
078  @Override
079  public SimpleMatrix getCoefficients() {
080    return m_coefficients;
081  }
082
083  /**
084   * Returns the initial control vector that created this spline.
085   *
086   * @return The initial control vector that created this spline.
087   */
088  @Override
089  public ControlVector getInitialControlVector() {
090    return m_initialControlVector;
091  }
092
093  /**
094   * Returns the final control vector that created this spline.
095   *
096   * @return The final control vector that created this spline.
097   */
098  @Override
099  public ControlVector getFinalControlVector() {
100    return m_finalControlVector;
101  }
102
103  /**
104   * Returns the hermite basis matrix for cubic hermite spline interpolation.
105   *
106   * @return The hermite basis matrix for cubic hermite spline interpolation.
107   */
108  private SimpleMatrix makeHermiteBasis() {
109    if (hermiteBasis == null) {
110      // Given P(i), P'(i), P(i+1), P'(i+1), the control vectors, we want to find
111      // the coefficients of the spline P(t) = a₃t³ + a₂t² + a₁t + a₀.
112      //
113      // P(i)    = P(0)  = a₀
114      // P'(i)   = P'(0) = a₁
115      // P(i+1)  = P(1)  = a₃ + a₂ + a₁ + a₀
116      // P'(i+1) = P'(1) = 3a₃ + 2a₂ + a₁
117      //
118      // [P(i)   ] = [0 0 0 1][a₃]
119      // [P'(i)  ] = [0 0 1 0][a₂]
120      // [P(i+1) ] = [1 1 1 1][a₁]
121      // [P'(i+1)] = [3 2 1 0][a₀]
122      //
123      // To solve for the coefficients, we can invert the 4x4 matrix and move it
124      // to the other side of the equation.
125      //
126      // [a₃] = [ 2  1 -2  1][P(i)   ]
127      // [a₂] = [-3 -2  3 -1][P'(i)  ]
128      // [a₁] = [ 0  1  0  0][P(i+1) ]
129      // [a₀] = [ 1  0  0  0][P'(i+1)]
130      hermiteBasis =
131          new SimpleMatrix(
132              4,
133              4,
134              true,
135              new double[] {
136                +2.0, +1.0, -2.0, +1.0, -3.0, -2.0, +3.0, -1.0, +0.0, +1.0, +0.0, +0.0, +1.0, +0.0,
137                +0.0, +0.0
138              });
139    }
140    return hermiteBasis;
141  }
142
143  /**
144   * Returns the control vector for each dimension as a matrix from the user-provided arrays in the
145   * constructor.
146   *
147   * @param initialVector The control vector for the initial point.
148   * @param finalVector The control vector for the final point.
149   * @return The control vector matrix for a dimension.
150   */
151  private SimpleMatrix getControlVectorFromArrays(double[] initialVector, double[] finalVector) {
152    if (initialVector.length < 2 || finalVector.length < 2) {
153      throw new IllegalArgumentException("Size of vectors must be 2 or greater.");
154    }
155    return new SimpleMatrix(
156        4,
157        1,
158        true,
159        new double[] {
160          initialVector[0], initialVector[1],
161          finalVector[0], finalVector[1]
162        });
163  }
164}