WPILibC++ 2027.0.0-alpha-3
Loading...
Searching...
No Matches
adjoint_expression_graph.hpp
Go to the documentation of this file.
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <ranges>
6#include <utility>
7
8#include <Eigen/SparseCore>
10
15
16namespace slp::detail {
17
18/**
19 * This class is an adaptor type that performs value updates of an expression's
20 * adjoint graph.
21 */
23 public:
24 /**
25 * Generates the adjoint graph for the given expression.
26 *
27 * @param root The root node of the expression.
28 */
29 explicit AdjointExpressionGraph(const Variable& root)
30 : m_top_list{topological_sort(root.expr)} {
31 for (const auto& node : m_top_list) {
32 m_col_list.emplace_back(node->col);
33 }
34 }
35
36 /**
37 * Update the values of all nodes in this adjoint graph based on the values of
38 * their dependent nodes.
39 */
40 void update_values() { detail::update_values(m_top_list); }
41
42 /**
43 * Returns the variable's gradient tree.
44 *
45 * This function lazily allocates variables, so elements of the returned
46 * VariableMatrix will be empty if the corresponding element of wrt had no
47 * adjoint. Ensure Variable::expr isn't nullptr before calling member
48 * functions.
49 *
50 * @param wrt Variables with respect to which to compute the gradient.
51 * @return The variable's gradient tree.
52 */
54 slp_assert(wrt.cols() == 1);
55
56 // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
57 // for background on reverse accumulation automatic differentiation.
58
59 if (m_top_list.empty()) {
61 }
62
63 // Set root node's adjoint to 1 since df/df is 1
64 m_top_list[0]->adjoint_expr = make_expression_ptr<ConstExpression>(1.0);
65
66 // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
67 // multiplied by dy/dx. If there are multiple "paths" from the root node to
68 // variable; the variable's adjoint is the sum of each path's adjoint
69 // contribution.
70 for (auto& node : m_top_list) {
71 auto& lhs = node->args[0];
72 auto& rhs = node->args[1];
73
74 if (lhs != nullptr) {
75 lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
76 if (rhs != nullptr) {
77 rhs->adjoint_expr += node->grad_expr_r(lhs, rhs, node->adjoint_expr);
78 }
79 }
80 }
81
82 // Move gradient tree to return value
84 for (int row = 0; row < grad.rows(); ++row) {
85 grad[row] = Variable{std::move(wrt[row].expr->adjoint_expr)};
86 }
87
88 // Unlink adjoints to avoid circular references between them and their
89 // parent expressions. This ensures all expressions are returned to the free
90 // list.
91 for (auto& node : m_top_list) {
92 node->adjoint_expr = nullptr;
93 }
94
95 return grad;
96 }
97
98 /**
99 * Updates the adjoints in the expression graph (computes the gradient) then
100 * appends the adjoints of wrt to the sparse matrix triplets.
101 *
102 * @param triplets The sparse matrix triplets.
103 * @param row The row of wrt.
104 * @param wrt Vector of variables with respect to which to compute the
105 * Jacobian.
106 */
108 gch::small_vector<Eigen::Triplet<double>>& triplets, int row,
109 const VariableMatrix& wrt) const {
110 // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
111 // for background on reverse accumulation automatic differentiation.
112
113 // If wrt has fewer nodes than graph, zero wrt's adjoints
114 if (static_cast<size_t>(wrt.rows()) < m_top_list.size()) {
115 for (const auto& elem : wrt) {
116 elem.expr->adjoint = 0.0;
117 }
118 }
119
120 if (m_top_list.empty()) {
121 return;
122 }
123
124 // Set root node's adjoint to 1 since df/df is 1
125 m_top_list[0]->adjoint = 1.0;
126
127 // Zero the rest of the adjoints
128 for (auto& node : m_top_list | std::views::drop(1)) {
129 node->adjoint = 0.0;
130 }
131
132 // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
133 // multiplied by dy/dx. If there are multiple "paths" from the root node to
134 // variable; the variable's adjoint is the sum of each path's adjoint
135 // contribution.
136 for (const auto& node : m_top_list) {
137 auto& lhs = node->args[0];
138 auto& rhs = node->args[1];
139
140 if (lhs != nullptr) {
141 if (rhs != nullptr) {
142 lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint);
143 rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint);
144 } else {
145 lhs->adjoint += node->grad_l(lhs->val, 0.0, node->adjoint);
146 }
147 }
148 }
149
150 // If wrt has fewer nodes than graph, iterate over wrt
151 if (static_cast<size_t>(wrt.rows()) < m_top_list.size()) {
152 for (int col = 0; col < wrt.rows(); ++col) {
153 const auto& node = wrt[col].expr;
154
155 // Append adjoints of wrt to sparse matrix triplets
156 if (node->adjoint != 0.0) {
157 triplets.emplace_back(row, col, node->adjoint);
158 }
159 }
160 } else {
161 for (size_t i = 0; i < m_top_list.size(); ++i) {
162 const auto& col = m_col_list[i];
163 const auto& node = m_top_list[i];
164
165 // Append adjoints of wrt to sparse matrix triplets
166 if (col != -1 && node->adjoint != 0.0) {
167 triplets.emplace_back(row, col, node->adjoint);
168 }
169 }
170 }
171 }
172
173 private:
174 // Topological sort of graph from parent to child
176
177 // List that maps nodes to their respective column
178 gch::small_vector<int> m_col_list;
179};
180
181} // namespace slp::detail
#define slp_assert(condition)
Abort in C++.
Definition assert.hpp:27
An autodiff variable pointing to an expression node.
Definition variable.hpp:40
A matrix of autodiff variables.
Definition variable_matrix.hpp:29
int rows() const
Returns the number of rows in the matrix.
Definition variable_matrix.hpp:914
int cols() const
Returns the number of columns in the matrix.
Definition variable_matrix.hpp:921
static constexpr empty_t empty
Designates an uninitialized VariableMatrix.
Definition variable_matrix.hpp:39
This class is an adaptor type that performs value updates of an expression's adjoint graph.
Definition adjoint_expression_graph.hpp:22
VariableMatrix generate_gradient_tree(const VariableMatrix &wrt) const
Returns the variable's gradient tree.
Definition adjoint_expression_graph.hpp:53
void update_values()
Update the values of all nodes in this adjoint graph based on the values of their dependent nodes.
Definition adjoint_expression_graph.hpp:40
AdjointExpressionGraph(const Variable &root)
Generates the adjoint graph for the given expression.
Definition adjoint_expression_graph.hpp:29
void append_adjoint_triplets(gch::small_vector< Eigen::Triplet< double > > &triplets, int row, const VariableMatrix &wrt) const
Updates the adjoints in the expression graph (computes the gradient), then appends the adjoints of wrt to the sparse matrix triplets.
Definition adjoint_expression_graph.hpp:107
wpi::SmallVector< T > small_vector
Definition small_vector.hpp:10
Definition expression_graph.hpp:11
gch::small_vector< Expression * > topological_sort(const ExpressionPtr &root)
Generate a topological sort of an expression graph from parent to child.
Definition expression_graph.hpp:20
void update_values(const gch::small_vector< Expression * > &list)
Update the values of all nodes in this graph based on the values of their dependent nodes.
Definition expression_graph.hpp:77
static ExpressionPtr make_expression_ptr(Args &&... args)
Creates an intrusive shared pointer to an expression from the global pool allocator.
Definition expression.hpp:48