8#include <Eigen/SparseCore>
31 for (
const auto& node : m_top_list) {
32 m_col_list.emplace_back(node->col);
59 if (m_top_list.empty()) {
70 for (
auto& node : m_top_list) {
71 auto& lhs = node->args[0];
72 auto& rhs = node->args[1];
75 lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
77 rhs->adjoint_expr += node->grad_expr_r(lhs, rhs, node->adjoint_expr);
84 for (
int row = 0; row < grad.rows(); ++row) {
85 grad[row] =
Variable{std::move(wrt[row].expr->adjoint_expr)};
91 for (
auto& node : m_top_list) {
92 node->adjoint_expr =
nullptr;
114 if (
static_cast<size_t>(wrt.
rows()) < m_top_list.size()) {
115 for (
const auto& elem : wrt) {
116 elem.expr->adjoint = 0.0;
120 if (m_top_list.empty()) {
125 m_top_list[0]->adjoint = 1.0;
128 for (
auto& node : m_top_list | std::views::drop(1)) {
136 for (
const auto& node : m_top_list) {
137 auto& lhs = node->args[0];
138 auto& rhs = node->args[1];
140 if (lhs !=
nullptr) {
141 if (rhs !=
nullptr) {
142 lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint);
143 rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint);
145 lhs->adjoint += node->grad_l(lhs->val, 0.0, node->adjoint);
151 if (
static_cast<size_t>(wrt.
rows()) < m_top_list.size()) {
152 for (
int col = 0; col < wrt.
rows(); ++col) {
153 const auto& node = wrt[col].expr;
156 if (node->adjoint != 0.0) {
157 triplets.emplace_back(row, col, node->adjoint);
161 for (
size_t i = 0; i < m_top_list.size(); ++i) {
162 const auto& col = m_col_list[i];
163 const auto& node = m_top_list[i];
166 if (col != -1 && node->adjoint != 0.0) {
167 triplets.emplace_back(row, col, node->adjoint);
#define slp_assert(condition)
Abort in C++.
Definition assert.hpp:27
An autodiff variable pointing to an expression node.
Definition variable.hpp:40
A matrix of autodiff variables.
Definition variable_matrix.hpp:29
int rows() const
Returns the number of rows in the matrix.
Definition variable_matrix.hpp:914
int cols() const
Returns the number of columns in the matrix.
Definition variable_matrix.hpp:921
static constexpr empty_t empty
Designates an uninitialized VariableMatrix.
Definition variable_matrix.hpp:39
This class is an adaptor type that performs value updates of an expression's adjoint graph.
Definition adjoint_expression_graph.hpp:22
VariableMatrix generate_gradient_tree(const VariableMatrix &wrt) const
Returns the variable's gradient tree.
Definition adjoint_expression_graph.hpp:53
void update_values()
Update the values of all nodes in this adjoint graph based on the values of their dependent nodes.
Definition adjoint_expression_graph.hpp:40
AdjointExpressionGraph(const Variable &root)
Generates the adjoint graph for the given expression.
Definition adjoint_expression_graph.hpp:29
void append_adjoint_triplets(gch::small_vector< Eigen::Triplet< double > > &triplets, int row, const VariableMatrix &wrt) const
Updates the adjoints in the expression graph (computes the gradient), then appends each nonzero adjoint of wrt to the triplet list as an entry (row, col, adjoint) in the given row.
Definition adjoint_expression_graph.hpp:107
wpi::SmallVector< T > small_vector
Definition small_vector.hpp:10
Definition expression_graph.hpp:11
gch::small_vector< Expression * > topological_sort(const ExpressionPtr &root)
Generate a topological sort of an expression graph from parent to child.
Definition expression_graph.hpp:20
void update_values(const gch::small_vector< Expression * > &list)
Update the values of all nodes in this graph based on the values of their dependent nodes.
Definition expression_graph.hpp:77
static ExpressionPtr make_expression_ptr(Args &&... args)
Creates an intrusive shared pointer to an expression from the global pool allocator.
Definition expression.hpp:48