001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.spline;
006
007import org.ejml.simple.SimpleMatrix;
008
009public class CubicHermiteSpline extends Spline {
010  private static SimpleMatrix hermiteBasis;
011  private final SimpleMatrix m_coefficients;
012
013  private final ControlVector m_initialControlVector;
014  private final ControlVector m_finalControlVector;
015
016  /**
017   * Constructs a cubic hermite spline with the specified control vectors. Each control vector
018   * contains info about the location of the point and its first derivative.
019   *
020   * @param xInitialControlVector The control vector for the initial point in the x dimension.
021   * @param xFinalControlVector The control vector for the final point in the x dimension.
022   * @param yInitialControlVector The control vector for the initial point in the y dimension.
023   * @param yFinalControlVector The control vector for the final point in the y dimension.
024   */
025  public CubicHermiteSpline(
026      double[] xInitialControlVector,
027      double[] xFinalControlVector,
028      double[] yInitialControlVector,
029      double[] yFinalControlVector) {
030    super(3);
031
032    // Populate the coefficients for the actual spline equations.
033    // Row 0 is x coefficients
034    // Row 1 is y coefficients
035    final var hermite = makeHermiteBasis();
036    final var x = getControlVectorFromArrays(xInitialControlVector, xFinalControlVector);
037    final var y = getControlVectorFromArrays(yInitialControlVector, yFinalControlVector);
038
039    final var xCoeffs = (hermite.mult(x)).transpose();
040    final var yCoeffs = (hermite.mult(y)).transpose();
041
042    m_coefficients = new SimpleMatrix(6, 4);
043
044    for (int i = 0; i < 4; i++) {
045      m_coefficients.set(0, i, xCoeffs.get(0, i));
046      m_coefficients.set(1, i, yCoeffs.get(0, i));
047
048      // Populate Row 2 and Row 3 with the derivatives of the equations above.
049      // Then populate row 4 and 5 with the second derivatives.
050      // Here, we are multiplying by (3 - i) to manually take the derivative. The
051      // power of the term in index 0 is 3, index 1 is 2 and so on. To find the
052      // coefficient of the derivative, we can use the power rule and multiply
053      // the existing coefficient by its power.
054      m_coefficients.set(2, i, m_coefficients.get(0, i) * (3 - i));
055      m_coefficients.set(3, i, m_coefficients.get(1, i) * (3 - i));
056    }
057
058    for (int i = 0; i < 3; i++) {
059      // Here, we are multiplying by (2 - i) to manually take the derivative. The
060      // power of the term in index 0 is 2, index 1 is 1 and so on. To find the
061      // coefficient of the derivative, we can use the power rule and multiply
062      // the existing coefficient by its power.
063      m_coefficients.set(4, i, m_coefficients.get(2, i) * (2 - i));
064      m_coefficients.set(5, i, m_coefficients.get(3, i) * (2 - i));
065    }
066
067    // Assign member variables.
068    m_initialControlVector = new ControlVector(xInitialControlVector, yInitialControlVector);
069    m_finalControlVector = new ControlVector(xFinalControlVector, yFinalControlVector);
070  }
071
072  /**
073   * Returns the coefficients matrix.
074   *
075   * @return The coefficients matrix.
076   */
077  @Override
078  public SimpleMatrix getCoefficients() {
079    return m_coefficients;
080  }
081
082  /**
083   * Returns the initial control vector that created this spline.
084   *
085   * @return The initial control vector that created this spline.
086   */
087  @Override
088  public ControlVector getInitialControlVector() {
089    return m_initialControlVector;
090  }
091
092  /**
093   * Returns the final control vector that created this spline.
094   *
095   * @return The final control vector that created this spline.
096   */
097  @Override
098  public ControlVector getFinalControlVector() {
099    return m_finalControlVector;
100  }
101
102  /**
103   * Returns the hermite basis matrix for cubic hermite spline interpolation.
104   *
105   * @return The hermite basis matrix for cubic hermite spline interpolation.
106   */
107  private SimpleMatrix makeHermiteBasis() {
108    if (hermiteBasis == null) {
109      // Given P(i), P'(i), P(i+1), P'(i+1), the control vectors, we want to find
110      // the coefficients of the spline P(t) = a₃t³ + a₂t² + a₁t + a₀.
111      //
112      // P(i)    = P(0)  = a₀
113      // P'(i)   = P'(0) = a₁
114      // P(i+1)  = P(1)  = a₃ + a₂ + a₁ + a₀
115      // P'(i+1) = P'(1) = 3a₃ + 2a₂ + a₁
116      //
117      // [P(i)   ] = [0 0 0 1][a₃]
118      // [P'(i)  ] = [0 0 1 0][a₂]
119      // [P(i+1) ] = [1 1 1 1][a₁]
120      // [P'(i+1)] = [3 2 1 0][a₀]
121      //
122      // To solve for the coefficients, we can invert the 4x4 matrix and move it
123      // to the other side of the equation.
124      //
125      // [a₃] = [ 2  1 -2  1][P(i)   ]
126      // [a₂] = [-3 -2  3 -1][P'(i)  ]
127      // [a₁] = [ 0  1  0  0][P(i+1) ]
128      // [a₀] = [ 1  0  0  0][P'(i+1)]
129      hermiteBasis =
130          new SimpleMatrix(
131              4,
132              4,
133              true,
134              new double[] {
135                +2.0, +1.0, -2.0, +1.0, -3.0, -2.0, +3.0, -1.0, +0.0, +1.0, +0.0, +0.0, +1.0, +0.0,
136                +0.0, +0.0
137              });
138    }
139    return hermiteBasis;
140  }
141
142  /**
143   * Returns the control vector for each dimension as a matrix from the user-provided arrays in the
144   * constructor.
145   *
146   * @param initialVector The control vector for the initial point.
147   * @param finalVector The control vector for the final point.
148   * @return The control vector matrix for a dimension.
149   */
150  private SimpleMatrix getControlVectorFromArrays(double[] initialVector, double[] finalVector) {
151    if (initialVector.length < 2 || finalVector.length < 2) {
152      throw new IllegalArgumentException("Size of vectors must be 2 or greater.");
153    }
154    return new SimpleMatrix(
155        4,
156        1,
157        true,
158        new double[] {
159          initialVector[0], initialVector[1],
160          finalVector[0], finalVector[1]
161        });
162  }
163}