WPILibC++ 2027.0.0-alpha-4
Loading...
Searching...
No Matches
gradient_expression_graph.hpp
Go to the documentation of this file.
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
#include <ranges>
#include <utility>

#include <Eigen/SparseCore>
#include <gch/small_vector.hpp>

// NOTE(review): the following project includes were dropped by extraction and
// are reconstructed from the Doxygen cross-references — confirm exact paths.
#include "sleipnir/autodiff/expression.hpp"
#include "sleipnir/autodiff/expression_graph.hpp"
#include "sleipnir/autodiff/variable.hpp"
#include "sleipnir/autodiff/variable_matrix.hpp"
#include "sleipnir/util/assert.hpp"
#include "sleipnir/util/empty.hpp"
17namespace slp::detail {
18
19/// This class is an adapter type that performs value updates of an expression
20/// graph, generates a gradient tree, or appends gradient triplets for creating
21/// a sparse matrix of gradients.
22///
23/// @tparam Scalar Scalar type.
24template <typename Scalar>
26 public:
27 /// Generates the gradient graph for the given expression.
28 ///
29 /// @param root The root node of the expression.
31 : m_top_list{topological_sort(root.expr)} {
32 for (const auto& node : m_top_list) {
33 m_col_list.emplace_back(node->col);
34 }
35 }
36
37 /// Update the values of all nodes in this graph based on the values of their
38 /// dependent nodes.
39 void update_values() { detail::update_values(m_top_list); }
40
41 /// Returns the variable's gradient tree.
42 ///
43 /// This function lazily allocates variables, so elements of the returned
44 /// VariableMatrix will be empty if the corresponding element of wrt had no
45 /// adjoint. Ensure Variable::expr isn't nullptr before calling member
46 /// functions.
47 ///
48 /// @param wrt Variables with respect to which to compute the gradient.
49 /// @return The variable's gradient tree.
51 const VariableMatrix<Scalar>& wrt) const {
52 slp_assert(wrt.cols() == 1);
53
54 // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
55 // for background on reverse accumulation automatic differentiation.
56
57 if (m_top_list.empty()) {
59 }
60
61 // Set root node's adjoint to 1 since df/df is 1
62 m_top_list[0]->adjoint_expr = constant_ptr(Scalar(1));
63
64 // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
65 // multiplied by dy/dx. If there are multiple "paths" from the root node to
66 // variable; the variable's adjoint is the sum of each path's adjoint
67 // contribution.
68 for (auto& node : m_top_list) {
69 auto& lhs = node->args[0];
70 auto& rhs = node->args[1];
71
72 if (lhs != nullptr) {
73 if (rhs != nullptr) {
74 // Binary operator
75 lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
76 rhs->adjoint_expr += node->grad_expr_r(lhs, rhs, node->adjoint_expr);
77 } else {
78 // Unary operator
79 lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
80 }
81 }
82 }
83
84 // Move gradient tree to return value
86 for (int row = 0; row < grad.rows(); ++row) {
87 grad[row] = Variable{std::move(wrt[row].expr->adjoint_expr)};
88 }
89
90 // Unlink adjoints to avoid circular references between them and their
91 // parent expressions. This ensures all expressions are returned to the free
92 // list.
93 for (auto& node : m_top_list) {
94 node->adjoint_expr = nullptr;
95 }
96
97 return grad;
98 }
99
100 /// Updates the adjoints in the expression graph (computes the gradient) then
101 /// appends the adjoints of wrt to the sparse matrix triplets.
102 ///
103 /// @param triplets The sparse matrix triplets.
104 /// @param row The row of wrt.
105 /// @param wrt Vector of variables with respect to which to compute the
106 /// Jacobian.
107 void append_triplets(gch::small_vector<Eigen::Triplet<Scalar>>& triplets,
108 int row, const VariableMatrix<Scalar>& wrt) const {
109 slp_assert(wrt.cols() == 1);
110
111 // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
112 // for background on reverse accumulation automatic differentiation.
113
114 // If wrt has fewer nodes than graph, zero wrt's adjoints
115 if (static_cast<size_t>(wrt.rows()) < m_top_list.size()) {
116 for (const auto& elem : wrt) {
117 elem.expr->adjoint = Scalar(0);
118 }
119 }
120
121 if (m_top_list.empty()) {
122 return;
123 }
124
125 // Set root node's adjoint to 1 since df/df is 1
126 m_top_list[0]->adjoint = Scalar(1);
127
128 // Zero the rest of the adjoints
129 for (auto& node : m_top_list | std::views::drop(1)) {
130 node->adjoint = Scalar(0);
131 }
132
133 // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
134 // multiplied by dy/dx. If there are multiple "paths" from the root node to
135 // variable; the variable's adjoint is the sum of each path's adjoint
136 // contribution.
137 for (const auto& node : m_top_list) {
138 auto& lhs = node->args[0];
139 auto& rhs = node->args[1];
140
141 if (lhs != nullptr) {
142 if (rhs != nullptr) {
143 // Binary operator
144 lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint);
145 rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint);
146 } else {
147 // Unary operator
148 lhs->adjoint += node->grad_l(lhs->val, Scalar(0), node->adjoint);
149 }
150 }
151 }
152
153 // If wrt has fewer nodes than graph, iterate over wrt
154 if (static_cast<size_t>(wrt.rows()) < m_top_list.size()) {
155 for (int col = 0; col < wrt.rows(); ++col) {
156 const auto& node = wrt[col].expr;
157
158 // Append adjoints of wrt to sparse matrix triplets
159 if (node->adjoint != Scalar(0)) {
160 triplets.emplace_back(row, col, node->adjoint);
161 }
162 }
163 } else {
164 for (size_t i = 0; i < m_top_list.size(); ++i) {
165 const auto& col = m_col_list[i];
166 const auto& node = m_top_list[i];
167
168 // Append adjoints of wrt to sparse matrix triplets
169 if (col != -1 && node->adjoint != Scalar(0)) {
170 triplets.emplace_back(row, col, node->adjoint);
171 }
172 }
173 }
174 }
175
176 private:
177 // Topological sort of graph from parent to child
179
180 // List that maps nodes to their respective column
181 gch::small_vector<int> m_col_list;
182};
183
184} // namespace slp::detail
#define slp_assert(condition)
Abort in C++.
Definition assert.hpp:25
An autodiff variable pointing to an expression node.
Definition variable.hpp:47
A matrix of autodiff variables.
Definition variable_matrix.hpp:33
int rows() const
Returns the number of rows in the matrix.
Definition variable_matrix.hpp:972
int cols() const
Returns the number of columns in the matrix.
Definition variable_matrix.hpp:977
GradientExpressionGraph(const Variable< Scalar > &root)
Generates the gradient graph for the given expression.
Definition gradient_expression_graph.hpp:30
void append_triplets(gch::small_vector< Eigen::Triplet< Scalar > > &triplets, int row, const VariableMatrix< Scalar > &wrt) const
Updates the adjoints in the expression graph (computes the gradient) then appends the adjoints of wrt...
Definition gradient_expression_graph.hpp:107
void update_values()
Update the values of all nodes in this graph based on the values of their dependent nodes.
Definition gradient_expression_graph.hpp:39
VariableMatrix< Scalar > generate_tree(const VariableMatrix< Scalar > &wrt) const
Returns the variable's gradient tree.
Definition gradient_expression_graph.hpp:50
wpi::util::SmallVector< T > small_vector
Definition small_vector.hpp:10
Definition expression_graph.hpp:11
static constexpr empty_t empty
Designates an uninitialized VariableMatrix.
Definition empty.hpp:11
void update_values(const gch::small_vector< Expression< Scalar > * > &list)
Update the values of all nodes in this graph based on the values of their dependent nodes.
Definition expression_graph.hpp:77
ExpressionPtr< Scalar > constant_ptr(Scalar value)
Creates an intrusive shared pointer to a constant expression.
Definition expression.hpp:417
gch::small_vector< Expression< Scalar > * > topological_sort(const ExpressionPtr< Scalar > &root)
Generate a topological sort of an expression graph from parent to child.
Definition expression_graph.hpp:20