29 if (root ==
nullptr || root->type == ExpressionType::kConstant) {
46 while (!stack.
empty()) {
47 auto currentNode = stack.
back();
50 for (
auto&&
arg : currentNode->args) {
53 if (
arg !=
nullptr &&
arg->type != ExpressionType::kConstant) {
56 if (
arg->duplications == 0) {
66 while (!stack.
empty()) {
67 auto currentNode = stack.
back();
71 m_rowList.emplace_back(currentNode->row);
72 m_adjointList.emplace_back(currentNode);
73 if (currentNode->valueFunc !=
nullptr) {
76 m_valueList.emplace_back(currentNode);
79 for (
auto&&
arg : currentNode->args) {
82 if (
arg !=
nullptr &&
arg->type != ExpressionType::kConstant) {
87 if (
arg->duplications == 0) {
102 for (
auto it = m_valueList.rbegin(); it != m_valueList.rend(); ++it) {
105 auto& lhs = node->args[0];
106 auto& rhs = node->args[1];
108 if (lhs !=
nullptr) {
109 if (rhs !=
nullptr) {
110 node->value = node->valueFunc(lhs->value, rhs->value);
112 node->value = node->valueFunc(lhs->value, 0.0);
124 std::span<const ExpressionPtr> wrt)
const {
128 for (
size_t row = 0; row < wrt.size(); ++row) {
134 for (
size_t row = 0; row < wrt.size(); ++row) {
139 if (m_adjointList.size() > 0) {
140 m_adjointList[0]->adjointExpr = MakeExpressionPtr(1.0);
141 for (
auto it = m_adjointList.begin() + 1; it != m_adjointList.end();
144 node->adjointExpr = MakeExpressionPtr();
152 for (
auto node : m_adjointList) {
153 auto& lhs = node->args[0];
154 auto& rhs = node->args[1];
156 if (lhs !=
nullptr && !lhs->IsConstant(0.0)) {
157 lhs->adjointExpr = lhs->adjointExpr +
158 node->gradientFuncs[0](lhs, rhs, node->adjointExpr);
160 if (rhs !=
nullptr && !rhs->IsConstant(0.0)) {
161 rhs->adjointExpr = rhs->adjointExpr +
162 node->gradientFuncs[1](lhs, rhs, node->adjointExpr);
166 if (node->row != -1) {
167 grad[node->row] = node->adjointExpr;
174 for (
auto node : m_adjointList) {
175 for (
auto&
arg : node->args) {
176 if (
arg !=
nullptr) {
177 arg->adjointExpr =
nullptr;
182 for (
size_t row = 0; row < wrt.size(); ++row) {
198 m_adjointList[0]->adjoint = 1.0;
199 for (
auto it = m_adjointList.begin() + 1; it != m_adjointList.end(); ++it) {
208 for (
size_t col = 0; col < m_adjointList.size(); ++col) {
209 auto& node = m_adjointList[col];
210 auto& lhs = node->args[0];
211 auto& rhs = node->args[1];
213 if (lhs !=
nullptr) {
214 if (rhs !=
nullptr) {
215 lhs->adjoint += node->gradientValueFuncs[0](lhs->value, rhs->value,
217 rhs->adjoint += node->gradientValueFuncs[1](lhs->value, rhs->value,
221 node->gradientValueFuncs[0](lhs->value, 0.0, node->adjoint);
226 int row = m_rowList[col];
228 func(row, node->adjoint);
This file defines the SmallVector class.
#define SLEIPNIR_DLLEXPORT
Definition SymbolExports.hpp:34
constexpr T * Get() const noexcept
Returns the internal pointer.
Definition IntrusiveSharedPtr.hpp:111
This class is an adaptor type that performs value updates of an expression's computational graph in a...
Definition ExpressionGraph.hpp:19
void ComputeAdjoints(function_ref< void(int row, double adjoint)> func)
Updates the adjoints in the expression graph, effectively computing the gradient.
Definition ExpressionGraph.hpp:196
wpi::SmallVector< ExpressionPtr > GenerateGradientTree(std::span< const ExpressionPtr > wrt) const
Returns the variable's gradient tree.
Definition ExpressionGraph.hpp:123
ExpressionGraph(ExpressionPtr &root)
Generates the deduplicated computational graph for the given expression.
Definition ExpressionGraph.hpp:26
void Update()
Update the values of all nodes in this computational tree based on the values of their dependent node...
Definition ExpressionGraph.hpp:99
An implementation of std::function_ref, a lightweight non-owning reference to a callable.
Definition FunctionRef.hpp:17
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition SmallVector.h:1212
reference emplace_back(ArgTypes &&... Args)
Definition SmallVector.h:953
void reserve(size_type N)
Definition SmallVector.h:679
void pop_back()
Definition SmallVector.h:441
void push_back(const T &Elt)
Definition SmallVector.h:429
bool empty() const
Definition SmallVector.h:102
reference back()
Definition SmallVector.h:324
Definition ExpressionGraph.hpp:13
auto arg(const Char *name, const T &arg) -> detail::named_arg< Char, T >
Returns a named argument to be used in a formatting function.
Definition base.h:2775