8#include <Eigen/SparseCore>
24template <
typename Scalar>
32 for (
const auto& node : m_top_list) {
33 m_col_list.emplace_back(node->col);
57 if (m_top_list.empty()) {
68 for (
auto& node : m_top_list) {
69 auto& lhs = node->args[0];
70 auto& rhs = node->args[1];
75 lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
76 rhs->adjoint_expr += node->grad_expr_r(lhs, rhs, node->adjoint_expr);
79 lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
86 for (
int row = 0; row < grad.
rows(); ++row) {
87 grad[row] =
Variable{std::move(wrt[row].expr->adjoint_expr)};
93 for (
auto& node : m_top_list) {
94 node->adjoint_expr =
nullptr;
115 if (
static_cast<size_t>(wrt.
rows()) < m_top_list.size()) {
116 for (
const auto& elem : wrt) {
117 elem.expr->adjoint = Scalar(0);
121 if (m_top_list.empty()) {
126 m_top_list[0]->adjoint = Scalar(1);
129 for (
auto& node : m_top_list | std::views::drop(1)) {
130 node->adjoint = Scalar(0);
137 for (
const auto& node : m_top_list) {
138 auto& lhs = node->args[0];
139 auto& rhs = node->args[1];
141 if (lhs !=
nullptr) {
142 if (rhs !=
nullptr) {
144 lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint);
145 rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint);
148 lhs->adjoint += node->grad_l(lhs->val, Scalar(0), node->adjoint);
154 if (
static_cast<size_t>(wrt.
rows()) < m_top_list.size()) {
155 for (
int col = 0; col < wrt.
rows(); ++col) {
156 const auto& node = wrt[col].expr;
159 if (node->adjoint != Scalar(0)) {
160 triplets.emplace_back(row, col, node->adjoint);
164 for (
size_t i = 0; i < m_top_list.size(); ++i) {
165 const auto& col = m_col_list[i];
166 const auto& node = m_top_list[i];
169 if (col != -1 && node->adjoint != Scalar(0)) {
170 triplets.emplace_back(row, col, node->adjoint);
#define slp_assert(condition)
Abort in C++.
Definition assert.hpp:25
An autodiff variable pointing to an expression node.
Definition variable.hpp:47
A matrix of autodiff variables.
Definition variable_matrix.hpp:33
int rows() const
Returns the number of rows in the matrix.
Definition variable_matrix.hpp:972
int cols() const
Returns the number of columns in the matrix.
Definition variable_matrix.hpp:977
GradientExpressionGraph(const Variable< Scalar > &root)
Generates the gradient graph for the given expression.
Definition gradient_expression_graph.hpp:30
void append_triplets(gch::small_vector< Eigen::Triplet< Scalar > > &triplets, int row, const VariableMatrix< Scalar > &wrt) const
Updates the adjoints in the expression graph (computes the gradient) then appends the adjoints of wrt...
Definition gradient_expression_graph.hpp:107
void update_values()
Update the values of all nodes in this graph based on the values of their dependent nodes.
Definition gradient_expression_graph.hpp:39
VariableMatrix< Scalar > generate_tree(const VariableMatrix< Scalar > &wrt) const
Returns the variable's gradient tree.
Definition gradient_expression_graph.hpp:50
wpi::util::SmallVector< T > small_vector
Definition small_vector.hpp:10
Definition expression_graph.hpp:11
static constexpr empty_t empty
Designates an uninitialized VariableMatrix.
Definition empty.hpp:11
void update_values(const gch::small_vector< Expression< Scalar > * > &list)
Update the values of all nodes in this graph based on the values of their dependent nodes.
Definition expression_graph.hpp:77
ExpressionPtr< Scalar > constant_ptr(Scalar value)
Creates an intrusive shared pointer to a constant expression.
Definition expression.hpp:417
gch::small_vector< Expression< Scalar > * > topological_sort(const ExpressionPtr< Scalar > &root)
Generate a topological sort of an expression graph from parent to child.
Definition expression_graph.hpp:20