001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.system;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Nat;
009import edu.wpi.first.math.Num;
010import edu.wpi.first.math.numbers.N1;
011import java.util.function.BiFunction;
012import java.util.function.Function;
013
014public final class NumericalJacobian {
015  private NumericalJacobian() {
016    // Utility Class.
017  }
018
019  private static final double kEpsilon = 1e-5;
020
021  /**
022   * Computes the numerical Jacobian with respect to x for f(x).
023   *
024   * @param <Rows> Number of rows in the result of f(x).
025   * @param <States> Num representing the number of rows in the output of f.
026   * @param <Cols> Number of columns in the result of f(x).
027   * @param rows Number of rows in the result of f(x).
028   * @param cols Number of columns in the result of f(x).
029   * @param f Vector-valued function from which to compute the Jacobian.
030   * @param x Vector argument.
031   * @return The numerical Jacobian with respect to x for f(x, u, ...).
032   */
033  public static <Rows extends Num, Cols extends Num, States extends Num>
034      Matrix<Rows, Cols> numericalJacobian(
035          Nat<Rows> rows,
036          Nat<Cols> cols,
037          Function<Matrix<Cols, N1>, Matrix<States, N1>> f,
038          Matrix<Cols, N1> x) {
039    var result = new Matrix<>(rows, cols);
040
041    for (int i = 0; i < cols.getNum(); i++) {
042      var dxPlus = x.copy();
043      var dxMinus = x.copy();
044      dxPlus.set(i, 0, dxPlus.get(i, 0) + kEpsilon);
045      dxMinus.set(i, 0, dxMinus.get(i, 0) - kEpsilon);
046      var dF = f.apply(dxPlus).minus(f.apply(dxMinus)).div(2 * kEpsilon);
047
048      result.setColumn(i, Matrix.changeBoundsUnchecked(dF));
049    }
050
051    return result;
052  }
053
054  /**
055   * Returns numerical Jacobian with respect to x for f(x, u, ...).
056   *
057   * @param <Rows> Number of rows in the result of f(x, u).
058   * @param <States> Number of rows in x.
059   * @param <Inputs> Number of rows in the second input to f.
060   * @param <Outputs> Num representing the rows in the output of f.
061   * @param rows Number of rows in the result of f(x, u).
062   * @param states Number of rows in x.
063   * @param f Vector-valued function from which to compute Jacobian.
064   * @param x State vector.
065   * @param u Input vector.
066   * @return The numerical Jacobian with respect to x for f(x, u, ...).
067   */
068  public static <Rows extends Num, States extends Num, Inputs extends Num, Outputs extends Num>
069      Matrix<Rows, States> numericalJacobianX(
070          Nat<Rows> rows,
071          Nat<States> states,
072          BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> f,
073          Matrix<States, N1> x,
074          Matrix<Inputs, N1> u) {
075    return numericalJacobian(rows, states, _x -> f.apply(_x, u), x);
076  }
077
078  /**
079   * Returns the numerical Jacobian with respect to u for f(x, u).
080   *
081   * @param <States> The states of the system.
082   * @param <Inputs> The inputs to the system.
083   * @param <Rows> Number of rows in the result of f(x, u).
084   * @param rows Number of rows in the result of f(x, u).
085   * @param inputs Number of rows in u.
086   * @param f Vector-valued function from which to compute the Jacobian.
087   * @param x State vector.
088   * @param u Input vector.
089   * @return the numerical Jacobian with respect to u for f(x, u).
090   */
091  public static <Rows extends Num, States extends Num, Inputs extends Num>
092      Matrix<Rows, Inputs> numericalJacobianU(
093          Nat<Rows> rows,
094          Nat<Inputs> inputs,
095          BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
096          Matrix<States, N1> x,
097          Matrix<Inputs, N1> u) {
098    return numericalJacobian(rows, inputs, _u -> f.apply(x, _u), u);
099  }
100}