001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package org.wpilib.math.autodiff;
006
007import org.ejml.simple.SimpleMatrix;
008
009/**
010 * This class calculates the gradient of a variable with respect to a vector of variables.
011 *
012 * <p>The gradient is only recomputed if the variable expression is quadratic or higher order.
013 */
014public class Gradient implements AutoCloseable {
015  private long m_handle;
016  private int m_rows;
017
018  /**
019   * Constructs a Gradient object.
020   *
021   * @param variable Variable of which to compute the gradient.
022   * @param wrt Variable with respect to which to compute the gradient.
023   */
024  public Gradient(Variable variable, Variable wrt) {
025    this(variable, new VariableMatrix(wrt));
026  }
027
028  /**
029   * Constructs a Gradient object.
030   *
031   * @param variable Variable of which to compute the gradient.
032   * @param wrt Vector of variables with respect to which to compute the gradient.
033   */
034  public Gradient(Variable variable, VariableMatrix wrt) {
035    assert wrt.cols() == 1;
036
037    m_handle = GradientJNI.create(variable.getHandle(), wrt.getHandles());
038    m_rows = wrt.rows();
039  }
040
041  /**
042   * Constructs a Gradient object.
043   *
044   * @param variable Variable of which to compute the gradient.
045   * @param wrt Vector of variables with respect to which to compute the gradient.
046   */
047  public Gradient(Variable variable, VariableBlock wrt) {
048    this(variable, new VariableMatrix(wrt));
049  }
050
051  @Override
052  public void close() {
053    if (m_handle != 0) {
054      GradientJNI.destroy(m_handle);
055      m_handle = 0;
056    }
057  }
058
059  /**
060   * Returns the gradient as a VariableMatrix.
061   *
062   * <p>This is useful when constructing optimization problems with derivatives in them.
063   *
064   * @return The gradient as a VariableMatrix.
065   */
066  public VariableMatrix get() {
067    return new VariableMatrix(m_rows, 1, GradientJNI.get(m_handle));
068  }
069
070  /**
071   * Evaluates the gradient at wrt's value.
072   *
073   * @return The gradient at wrt's value.
074   */
075  public SimpleMatrix value() {
076    return GradientJNI.value(m_handle).toSimpleMatrix(m_rows, 1);
077  }
078}