WPILibC++ 2025.1.1
Loading...
Searching...
No Matches
ExpressionGraph.hpp
Go to the documentation of this file.
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <span>
6
7#include <wpi/SmallVector.h>
8
12
14
15/**
16 * This class is an adaptor type that performs value updates of an expression's
17 * computational graph in a way that skips duplicates.
18 */
20 public:
21 /**
22 * Generates the deduplicated computational graph for the given expression.
23 *
24 * @param root The root node of the expression.
25 */
27 // If the root type is a constant, Update() is a no-op, so there's no work
28 // to do
29 if (root == nullptr || root->type == ExpressionType::kConstant) {
30 return;
31 }
32
33 // Breadth-first search (BFS) is used as opposed to a depth-first search
34 // (DFS) to avoid counting duplicate nodes multiple times. A list of nodes
35 // ordered from parent to child with no duplicates is generated.
36 //
37 // https://en.wikipedia.org/wiki/Breadth-first_search
38
39 // BFS list sorted from parent to child.
41
42 stack.emplace_back(root.Get());
43
44 // Initialize the number of instances of each node in the tree
45 // (Expression::duplications)
46 while (!stack.empty()) {
47 auto currentNode = stack.back();
48 stack.pop_back();
49
50 for (auto&& arg : currentNode->args) {
51 // Only continue if the node is not a constant and hasn't already been
52 // explored.
53 if (arg != nullptr && arg->type != ExpressionType::kConstant) {
54 // If this is the first instance of the node encountered (it hasn't
55 // been explored yet), add it to stack so it's recursed upon
56 if (arg->duplications == 0) {
57 stack.push_back(arg.Get());
58 }
59 ++arg->duplications;
60 }
61 }
62 }
63
64 stack.emplace_back(root.Get());
65
66 while (!stack.empty()) {
67 auto currentNode = stack.back();
68 stack.pop_back();
69
70 // BFS lists sorted from parent to child.
71 m_rowList.emplace_back(currentNode->row);
72 m_adjointList.emplace_back(currentNode);
73 if (currentNode->valueFunc != nullptr) {
74 // Constants are skipped because they have no valueFunc and don't need
75 // to be updated
76 m_valueList.emplace_back(currentNode);
77 }
78
79 for (auto&& arg : currentNode->args) {
80 // Only add node if it's not a constant and doesn't already exist in the
81 // tape.
82 if (arg != nullptr && arg->type != ExpressionType::kConstant) {
83 // Once the number of node visitations equals the number of
84 // duplications (the counter hits zero), add it to the stack. Note
85 // that this means the node is only enqueued once.
86 --arg->duplications;
87 if (arg->duplications == 0) {
88 stack.push_back(arg.Get());
89 }
90 }
91 }
92 }
93 }
94
95 /**
96 * Update the values of all nodes in this computational tree based on the
97 * values of their dependent nodes.
98 */
99 void Update() {
100 // Traverse the BFS list backward from child to parent and update the value
101 // of each node.
102 for (auto it = m_valueList.rbegin(); it != m_valueList.rend(); ++it) {
103 auto& node = *it;
104
105 auto& lhs = node->args[0];
106 auto& rhs = node->args[1];
107
108 if (lhs != nullptr) {
109 if (rhs != nullptr) {
110 node->value = node->valueFunc(lhs->value, rhs->value);
111 } else {
112 node->value = node->valueFunc(lhs->value, 0.0);
113 }
114 }
115 }
116 }
117
118 /**
119 * Returns the variable's gradient tree.
120 *
121 * @param wrt Variables with respect to which to compute the gradient.
122 */
124 std::span<const ExpressionPtr> wrt) const {
125 // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
126 // for background on reverse accumulation automatic differentiation.
127
128 for (size_t row = 0; row < wrt.size(); ++row) {
129 wrt[row]->row = row;
130 }
131
133 grad.reserve(wrt.size());
134 for (size_t row = 0; row < wrt.size(); ++row) {
135 grad.emplace_back(MakeExpressionPtr());
136 }
137
138 // Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1.
139 if (m_adjointList.size() > 0) {
140 m_adjointList[0]->adjointExpr = MakeExpressionPtr(1.0);
141 for (auto it = m_adjointList.begin() + 1; it != m_adjointList.end();
142 ++it) {
143 auto& node = *it;
144 node->adjointExpr = MakeExpressionPtr();
145 }
146 }
147
148 // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
149 // multiplied by dy/dx. If there are multiple "paths" from the root node to
150 // variable; the variable's adjoint is the sum of each path's adjoint
151 // contribution.
152 for (auto node : m_adjointList) {
153 auto& lhs = node->args[0];
154 auto& rhs = node->args[1];
155
156 if (lhs != nullptr && !lhs->IsConstant(0.0)) {
157 lhs->adjointExpr = lhs->adjointExpr +
158 node->gradientFuncs[0](lhs, rhs, node->adjointExpr);
159 }
160 if (rhs != nullptr && !rhs->IsConstant(0.0)) {
161 rhs->adjointExpr = rhs->adjointExpr +
162 node->gradientFuncs[1](lhs, rhs, node->adjointExpr);
163 }
164
165 // If variable is a leaf node, assign its adjoint to the gradient.
166 if (node->row != -1) {
167 grad[node->row] = node->adjointExpr;
168 }
169 }
170
171 // Unlink adjoints to avoid circular references between them and their
172 // parent expressions. This ensures all expressions are returned to the free
173 // list.
174 for (auto node : m_adjointList) {
175 for (auto& arg : node->args) {
176 if (arg != nullptr) {
177 arg->adjointExpr = nullptr;
178 }
179 }
180 }
181
182 for (size_t row = 0; row < wrt.size(); ++row) {
183 wrt[row]->row = -1;
184 }
185
186 return grad;
187 }
188
189 /**
190 * Updates the adjoints in the expression graph, effectively computing the
191 * gradient.
192 *
193 * @param func A function that takes two arguments: an int for the gradient
194 * row, and a double for the adjoint (gradient value).
195 */
196 void ComputeAdjoints(function_ref<void(int row, double adjoint)> func) {
197 // Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1.
198 m_adjointList[0]->adjoint = 1.0;
199 for (auto it = m_adjointList.begin() + 1; it != m_adjointList.end(); ++it) {
200 auto& node = *it;
201 node->adjoint = 0.0;
202 }
203
204 // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
205 // multiplied by dy/dx. If there are multiple "paths" from the root node to
206 // variable; the variable's adjoint is the sum of each path's adjoint
207 // contribution.
208 for (size_t col = 0; col < m_adjointList.size(); ++col) {
209 auto& node = m_adjointList[col];
210 auto& lhs = node->args[0];
211 auto& rhs = node->args[1];
212
213 if (lhs != nullptr) {
214 if (rhs != nullptr) {
215 lhs->adjoint += node->gradientValueFuncs[0](lhs->value, rhs->value,
216 node->adjoint);
217 rhs->adjoint += node->gradientValueFuncs[1](lhs->value, rhs->value,
218 node->adjoint);
219 } else {
220 lhs->adjoint +=
221 node->gradientValueFuncs[0](lhs->value, 0.0, node->adjoint);
222 }
223 }
224
225 // If variable is a leaf node, assign its adjoint to the gradient.
226 int row = m_rowList[col];
227 if (row != -1) {
228 func(row, node->adjoint);
229 }
230 }
231 }
232
233 private:
234 // List that maps nodes to their respective row.
235 wpi::SmallVector<int> m_rowList;
236
237 // List for updating adjoints
238 wpi::SmallVector<Expression*> m_adjointList;
239
240 // List for updating values
242};
243
244} // namespace sleipnir::detail
This file defines the SmallVector class.
#define SLEIPNIR_DLLEXPORT
Definition SymbolExports.hpp:34
constexpr T * Get() const noexcept
Returns the internal pointer.
Definition IntrusiveSharedPtr.hpp:111
This class is an adaptor type that performs value updates of an expression's computational graph in a...
Definition ExpressionGraph.hpp:19
void ComputeAdjoints(function_ref< void(int row, double adjoint)> func)
Updates the adjoints in the expression graph, effectively computing the gradient.
Definition ExpressionGraph.hpp:196
wpi::SmallVector< ExpressionPtr > GenerateGradientTree(std::span< const ExpressionPtr > wrt) const
Returns the variable's gradient tree.
Definition ExpressionGraph.hpp:123
ExpressionGraph(ExpressionPtr &root)
Generates the deduplicated computational graph for the given expression.
Definition ExpressionGraph.hpp:26
void Update()
Update the values of all nodes in this computational tree based on the values of their dependent node...
Definition ExpressionGraph.hpp:99
An implementation of std::function_ref, a lightweight non-owning reference to a callable.
Definition FunctionRef.hpp:17
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition SmallVector.h:1212
reference emplace_back(ArgTypes &&... Args)
Definition SmallVector.h:953
void reserve(size_type N)
Definition SmallVector.h:679
void pop_back()
Definition SmallVector.h:441
void push_back(const T &Elt)
Definition SmallVector.h:429
bool empty() const
Definition SmallVector.h:102
reference back()
Definition SmallVector.h:324
Definition ExpressionGraph.hpp:13
auto arg(const Char *name, const T &arg) -> detail::named_arg< Char, T >
Returns a named argument to be used in a formatting function.
Definition base.h:2775