29    if (root == 
nullptr || root->type == ExpressionType::kConstant) {
 
   46    while (!stack.
empty()) {
 
   47      auto currentNode = stack.
back();
 
   50      for (
auto&& 
arg : currentNode->args) {
 
   53        if (
arg != 
nullptr && 
arg->type != ExpressionType::kConstant) {
 
   56          if (
arg->duplications == 0) {
 
   66    while (!stack.
empty()) {
 
   67      auto currentNode = stack.
back();
 
   71      m_rowList.emplace_back(currentNode->row);
 
   72      m_adjointList.emplace_back(currentNode);
 
   73      if (currentNode->valueFunc != 
nullptr) {
 
   76        m_valueList.emplace_back(currentNode);
 
   79      for (
auto&& 
arg : currentNode->args) {
 
   82        if (
arg != 
nullptr && 
arg->type != ExpressionType::kConstant) {
 
   87          if (
arg->duplications == 0) {
 
 
  102    for (
auto it = m_valueList.rbegin(); it != m_valueList.rend(); ++it) {
 
  105      auto& lhs = node->args[0];
 
  106      auto& rhs = node->args[1];
 
  108      if (lhs != 
nullptr) {
 
  109        if (rhs != 
nullptr) {
 
  110          node->value = node->valueFunc(lhs->value, rhs->value);
 
  112          node->value = node->valueFunc(lhs->value, 0.0);
 
 
  124      std::span<const ExpressionPtr> wrt)
 const {
 
  128    for (
size_t row = 0; row < wrt.size(); ++row) {
 
  134    for (
size_t row = 0; row < wrt.size(); ++row) {
 
  139    if (m_adjointList.size() > 0) {
 
  140      m_adjointList[0]->adjointExpr = MakeExpressionPtr(1.0);
 
  141      for (
auto it = m_adjointList.begin() + 1; it != m_adjointList.end();
 
  144        node->adjointExpr = MakeExpressionPtr();
 
  152    for (
auto node : m_adjointList) {
 
  153      auto& lhs = node->args[0];
 
  154      auto& rhs = node->args[1];
 
  156      if (lhs != 
nullptr && !lhs->IsConstant(0.0)) {
 
  157        lhs->adjointExpr = lhs->adjointExpr +
 
  158                           node->gradientFuncs[0](lhs, rhs, node->adjointExpr);
 
  160      if (rhs != 
nullptr && !rhs->IsConstant(0.0)) {
 
  161        rhs->adjointExpr = rhs->adjointExpr +
 
  162                           node->gradientFuncs[1](lhs, rhs, node->adjointExpr);
 
  166      if (node->row != -1) {
 
  167        grad[node->row] = node->adjointExpr;
 
  174    for (
auto node : m_adjointList) {
 
  175      for (
auto& 
arg : node->args) {
 
  176        if (
arg != 
nullptr) {
 
  177          arg->adjointExpr = 
nullptr;
 
  182    for (
size_t row = 0; row < wrt.size(); ++row) {
 
 
  198    m_adjointList[0]->adjoint = 1.0;
 
  199    for (
auto it = m_adjointList.begin() + 1; it != m_adjointList.end(); ++it) {
 
  208    for (
size_t col = 0; col < m_adjointList.size(); ++col) {
 
  209      auto& node = m_adjointList[col];
 
  210      auto& lhs = node->args[0];
 
  211      auto& rhs = node->args[1];
 
  213      if (lhs != 
nullptr) {
 
  214        if (rhs != 
nullptr) {
 
  215          lhs->adjoint += node->gradientValueFuncs[0](lhs->value, rhs->value,
 
  217          rhs->adjoint += node->gradientValueFuncs[1](lhs->value, rhs->value,
 
  221              node->gradientValueFuncs[0](lhs->value, 0.0, node->adjoint);
 
  226      int row = m_rowList[col];
 
  228        func(row, node->adjoint);
 
 
 
 
This file defines the SmallVector class.
#define SLEIPNIR_DLLEXPORT
Definition SymbolExports.hpp:34
constexpr T * Get() const noexcept
Returns the internal pointer.
Definition IntrusiveSharedPtr.hpp:111
This class is an adaptor type that performs value updates of an expression's computational graph in a way that skips duplicates.
Definition ExpressionGraph.hpp:19
void ComputeAdjoints(function_ref< void(int row, double adjoint)> func)
Updates the adjoints in the expression graph, effectively computing the gradient.
Definition ExpressionGraph.hpp:196
wpi::SmallVector< ExpressionPtr > GenerateGradientTree(std::span< const ExpressionPtr > wrt) const
Returns the variable's gradient tree.
Definition ExpressionGraph.hpp:123
ExpressionGraph(ExpressionPtr &root)
Generates the deduplicated computational graph for the given expression.
Definition ExpressionGraph.hpp:26
void Update()
Update the values of all nodes in this computational tree based on the values of their dependent nodes.
Definition ExpressionGraph.hpp:99
An implementation of std::function_ref, a lightweight non-owning reference to a callable.
Definition FunctionRef.hpp:17
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition SmallVector.h:1212
reference emplace_back(ArgTypes &&... Args)
Definition SmallVector.h:953
void reserve(size_type N)
Definition SmallVector.h:679
void pop_back()
Definition SmallVector.h:441
void push_back(const T &Elt)
Definition SmallVector.h:429
bool empty() const
Definition SmallVector.h:102
reference back()
Definition SmallVector.h:324
Definition ExpressionGraph.hpp:13
auto arg(const Char *name, const T &arg) -> detail::named_arg< Char, T >
Returns a named argument to be used in a formatting function.
Definition base.h:2775