001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package org.wpilib.math.autodiff; 006 007import org.ejml.simple.SimpleMatrix; 008 009/** 010 * This class calculates the gradient of a variable with respect to a vector of variables. 011 * 012 * <p>The gradient is only recomputed if the variable expression is quadratic or higher order. 013 */ 014public class Gradient implements AutoCloseable { 015 private long m_handle; 016 private int m_rows; 017 018 /** 019 * Constructs a Gradient object. 020 * 021 * @param variable Variable of which to compute the gradient. 022 * @param wrt Variable with respect to which to compute the gradient. 023 */ 024 public Gradient(Variable variable, Variable wrt) { 025 this(variable, new VariableMatrix(wrt)); 026 } 027 028 /** 029 * Constructs a Gradient object. 030 * 031 * @param variable Variable of which to compute the gradient. 032 * @param wrt Vector of variables with respect to which to compute the gradient. 033 */ 034 public Gradient(Variable variable, VariableMatrix wrt) { 035 assert wrt.cols() == 1; 036 037 m_handle = GradientJNI.create(variable.getHandle(), wrt.getHandles()); 038 m_rows = wrt.rows(); 039 } 040 041 /** 042 * Constructs a Gradient object. 043 * 044 * @param variable Variable of which to compute the gradient. 045 * @param wrt Vector of variables with respect to which to compute the gradient. 046 */ 047 public Gradient(Variable variable, VariableBlock wrt) { 048 this(variable, new VariableMatrix(wrt)); 049 } 050 051 @Override 052 public void close() { 053 if (m_handle != 0) { 054 GradientJNI.destroy(m_handle); 055 m_handle = 0; 056 } 057 } 058 059 /** 060 * Returns the gradient as a VariableMatrix. 061 * 062 * <p>This is useful when constructing optimization problems with derivatives in them. 063 * 064 * @return The gradient as a VariableMatrix. 065 */ 066 public VariableMatrix get() { 067 return new VariableMatrix(m_rows, 1, GradientJNI.get(m_handle)); 068 } 069 070 /** 071 * Evaluates the gradient at wrt's value. 072 * 073 * @return The gradient at wrt's value. 074 */ 075 public SimpleMatrix value() { 076 return GradientJNI.value(m_handle).toSimpleMatrix(m_rows, 1); 077 } 078}