WPILibC++ 2027.0.0-alpha-4
Loading...
Searching...
No Matches
expression_graph.hpp
Go to the documentation of this file.
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <ranges>
6
8
10
11namespace slp::detail {
12
13/// Generate a topological sort of an expression graph from parent to child.
14///
15/// https://en.wikipedia.org/wiki/Topological_sorting
16///
17/// @tparam Scalar Scalar type.
18/// @param root The root node of the expression.
/// @return The nodes reachable from root, ordered so that every node appears
///   before its args (root first). Empty if root is null or a constant
///   expression.
///
/// Uses each visited node's Expression::incoming_edges as scratch space: the
/// first pass counts incoming edges and the second pass counts them back down.
/// NOTE(review): the first pass relies on these counters being zero on entry
/// (a node is pushed only when its count becomes 1); confirm callers keep that
/// invariant.
19template <typename Scalar>
21 const ExpressionPtr<Scalar>& root) {
23
24 // If the root is null or constant, updates are a no-op, so return an empty list
25 if (root == nullptr || root->type() == ExpressionType::CONSTANT) {
26 return list;
27 }
28
29 // Stack of nodes to explore
31
32 // Enumerate incoming edges for each node via depth-first search
33 stack.emplace_back(root.get());
34 while (!stack.empty()) {
35 auto node = stack.back();
36 stack.pop_back();
37
38 for (auto& arg : node->args) {
39 // If the node hasn't been explored yet, add it to the stack
// Every edge is counted, but a node shared by several parents is pushed only
// once (on the 0 -> 1 transition), so each node is explored a single time.
40 if (arg != nullptr && ++arg->incoming_edges == 1) {
41 stack.push_back(arg.get());
42 }
43 }
44 }
45
46 // Generate topological sort of graph from parent to child.
47 //
48 // A node is only added to the stack after all its incoming edges have been
49 // traversed. Expression::incoming_edges is a decrementing counter for
50 // tracking this.
51 //
52 // https://en.wikipedia.org/wiki/Topological_sorting
53 stack.emplace_back(root.get());
54 while (!stack.empty()) {
55 auto node = stack.back();
56 stack.pop_back();
57
58 list.emplace_back(node);
59
60 for (auto& arg : node->args) {
61 // If we traversed all this node's incoming edges, add it to the stack
62 if (arg != nullptr && --arg->incoming_edges == 0) {
63 stack.push_back(arg.get());
64 }
65 }
66 }
67
// For an acyclic graph every edge counted in the first pass has now been
// decremented once, so the incoming_edges counters are back to their starting
// values.
68 return list;
69}
70
71/// Update the values of all nodes in this graph based on the values of
72/// their dependent nodes.
73///
/// Nodes are visited in reverse order of @p list (child to parent), so when
/// @p list is a parent-to-child topological sort each node's args are
/// recomputed before the node itself. A node whose args[0] is null keeps its
/// current val. A null args[1] is passed to Expression::value() as Scalar(0).
///
74/// @tparam Scalar Scalar type.
75/// @param list Topological sort of graph from parent to child.
76template <typename Scalar>
78 // Traverse graph from child to parent and update values
79 for (auto& node : list | std::views::reverse) {
80 auto& lhs = node->args[0];
81 auto& rhs = node->args[1];
82
// Nodes without a first argument are not recomputed; their val is left as-is.
83 if (lhs != nullptr) {
// When rhs is null (presumably a unary expression), Scalar(0) stands in for it.
84 node->val = node->value(lhs->val, rhs ? rhs->val : Scalar(0));
85 }
86 }
87}
88
89} // namespace slp::detail
auto arg(const Char *name, const T &arg) -> detail::named_arg< Char, T >
Returns a named argument to be used in a formatting function.
Definition base.h:2846
constexpr T * get() const noexcept
Returns the internal pointer.
Definition intrusive_shared_ptr.hpp:178
wpi::util::SmallVector< T > small_vector
Definition small_vector.hpp:10
Definition expression_graph.hpp:11
void update_values(const gch::small_vector< Expression< Scalar > * > &list)
Update the values of all nodes in this graph based on the values of their dependent nodes.
Definition expression_graph.hpp:77
gch::small_vector< Expression< Scalar > * > topological_sort(const ExpressionPtr< Scalar > &root)
Generate a topological sort of an expression graph from parent to child.
Definition expression_graph.hpp:20
IntrusiveSharedPtr< Expression< Scalar > > ExpressionPtr
Typedef for intrusive shared pointer to Expression.
Definition expression.hpp:43
@ CONSTANT
The expression is a constant.
Definition expression_type.hpp:20
An autodiff expression node.
Definition expression.hpp:89
virtual ExpressionType type() const =0
Returns the type of this expression (constant, linear, quadratic, or nonlinear).