001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package org.wpilib.math.autodiff;
006
007import org.ejml.simple.SimpleMatrix;
008
009/**
010 * This class calculates the Hessian of a variable with respect to a vector of variables.
011 *
012 * <p>The gradient tree is cached so subsequent Hessian calculations are faster, and the Hessian is
013 * only recomputed if the variable expression is nonlinear.
014 */
015public class Hessian implements AutoCloseable {
016  private long m_handle;
017  private int m_rows;
018
019  /**
020   * Constructs a Hessian object.
021   *
022   * @param variable Variable of which to compute the Hessian.
023   * @param wrt Variable with respect to which to compute the Hessian.
024   */
025  public Hessian(Variable variable, Variable wrt) {
026    this(variable, new VariableMatrix(wrt));
027  }
028
029  /**
030   * Constructs a Hessian object.
031   *
032   * @param variable Variable of which to compute the Hessian.
033   * @param wrt Vector of variables with respect to which to compute the Hessian.
034   */
035  public Hessian(Variable variable, VariableMatrix wrt) {
036    assert wrt.cols() == 1;
037
038    m_handle = HessianJNI.create(variable.getHandle(), wrt.getHandles());
039    m_rows = wrt.rows();
040  }
041
042  /**
043   * Constructs a Hessian object.
044   *
045   * @param variable Variable of which to compute the Hessian.
046   * @param wrt Vector of variables with respect to which to compute the Hessian.
047   */
048  public Hessian(Variable variable, VariableBlock wrt) {
049    this(variable, new VariableMatrix(wrt));
050  }
051
052  @Override
053  public void close() {
054    if (m_handle != 0) {
055      HessianJNI.destroy(m_handle);
056      m_handle = 0;
057    }
058  }
059
060  /**
061   * Returns the Hessian as a VariableMatrix.
062   *
063   * <p>This is useful when constructing optimization problems with derivatives in them.
064   *
065   * @return The Hessian as a VariableMatrix.
066   */
067  public VariableMatrix get() {
068    return new VariableMatrix(m_rows, m_rows, HessianJNI.get(m_handle));
069  }
070
071  /**
072   * Evaluates the Hessian at wrt's value.
073   *
074   * @return The Hessian at wrt's value.
075   */
076  public SimpleMatrix value() {
077    return HessianJNI.value(m_handle).toSimpleMatrix(m_rows, m_rows);
078  }
079}