8#include <Eigen/SparseCore>
30 for (
const auto& node : m_top_list) {
31 m_col_list.emplace_back(node->col);
56 if (m_top_list.empty()) {
67 for (
auto& node : m_top_list) {
68 auto& lhs = node->args[0];
69 auto& rhs = node->args[1];
72 lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
74 rhs->adjoint_expr += node->grad_expr_r(lhs, rhs, node->adjoint_expr);
81 for (
int row = 0; row < grad.rows(); ++row) {
82 grad[row] =
Variable{std::move(wrt[row].expr->adjoint_expr)};
88 for (
auto& node : m_top_list) {
89 node->adjoint_expr =
nullptr;
111 if (
static_cast<size_t>(wrt.
rows()) < m_top_list.size()) {
112 for (
const auto& elem : wrt) {
113 elem.expr->adjoint = 0.0;
117 if (m_top_list.empty()) {
122 m_top_list[0]->adjoint = 1.0;
125 for (
auto& node : m_top_list | std::views::drop(1)) {
133 for (
const auto& node : m_top_list) {
134 auto& lhs = node->args[0];
135 auto& rhs = node->args[1];
137 if (lhs !=
nullptr) {
138 if (rhs !=
nullptr) {
139 lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint);
140 rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint);
142 lhs->adjoint += node->grad_l(lhs->val, 0.0, node->adjoint);
148 if (
static_cast<size_t>(wrt.
rows()) < m_top_list.size()) {
149 for (
int col = 0; col < wrt.
rows(); ++col) {
150 const auto& node = wrt[col].expr;
153 if (node->adjoint != 0.0) {
154 triplets.emplace_back(row, col, node->adjoint);
158 for (
size_t i = 0; i < m_top_list.size(); ++i) {
159 const auto& col = m_col_list[i];
160 const auto& node = m_top_list[i];
163 if (col != -1 && node->adjoint != 0.0) {
164 triplets.emplace_back(row, col, node->adjoint);
An autodiff variable pointing to an expression node.
Definition variable.hpp:40
A matrix of autodiff variables.
Definition variable_matrix.hpp:29
int rows() const
Returns the number of rows in the matrix.
Definition variable_matrix.hpp:914
static constexpr empty_t empty
Designates an uninitialized VariableMatrix.
Definition variable_matrix.hpp:39
This class is an adaptor type that performs value updates of an expression's adjoint graph.
Definition adjoint_expression_graph.hpp:21
VariableMatrix generate_gradient_tree(const VariableMatrix &wrt) const
Returns the variable's gradient tree.
Definition adjoint_expression_graph.hpp:52
void update_values()
Update the values of all nodes in this adjoint graph based on the values of their dependent nodes.
Definition adjoint_expression_graph.hpp:39
AdjointExpressionGraph(const Variable &root)
Generates the adjoint graph for the given expression.
Definition adjoint_expression_graph.hpp:28
void append_adjoint_triplets(gch::small_vector< Eigen::Triplet< double > > &triplets, int row, const VariableMatrix &wrt) const
Updates the adjoints in the expression graph (computes the gradient), then appends the adjoints of wrt to the sparse matrix triplets.
Definition adjoint_expression_graph.hpp:104
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition SmallVector.h:1198
Definition expression_graph.hpp:11
gch::small_vector< Expression * > topological_sort(const ExpressionPtr &root)
Generate a topological sort of an expression graph from parent to child.
Definition expression_graph.hpp:20
void update_values(const gch::small_vector< Expression * > &list)
Update the values of all nodes in this graph based on the values of their dependent nodes.
Definition expression_graph.hpp:78
static ExpressionPtr make_expression_ptr(Args &&... args)
Creates an intrusive shared pointer to an expression from the global pool allocator.
Definition expression.hpp:48