WPILibC++ 2027.0.0-alpha-4
Loading...
Searching...
No Matches
expression.hpp
Go to the documentation of this file.
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <stdint.h>
6
7#include <algorithm>
8#include <array>
9#include <cmath>
10#include <memory>
11#include <numbers>
12#include <string_view>
13#include <utility>
14
15#include <gch/small_vector.hpp>
16
20
21namespace slp::detail {
22
// Selects whether expression nodes come from the global pool allocator.
//
// The pool allocator is backed by a thread-local static pool resource, and on
// Windows such statics aren't guaranteed to be initialized properly across DLL
// boundaries. Pooling is therefore disabled on Windows and enabled elsewhere.
inline constexpr bool USE_POOL_ALLOCATOR =
#ifdef _WIN32
    false;
#else
    true;
#endif
30
// Forward declaration of the autodiff expression node, so that pointers to it
// can appear in signatures before its definition.
template <class Scalar>
struct Expression;

// Forward declaration of the refcount-increment hook for the intrusive shared
// pointer; the definition appears later in this header.
template <class Scalar>
constexpr void inc_ref_count(Expression<Scalar>* expr);
36template <typename Scalar>
38
39/// Typedef for intrusive shared pointer to Expression.
40///
41/// @tparam Scalar Scalar type.
42template <typename Scalar>
44
45/// Creates an intrusive shared pointer to an expression from the global pool
46/// allocator.
47///
48/// @tparam T The derived expression type.
49/// @param args Constructor arguments for Expression.
50template <typename T, typename... Args>
52 if constexpr (USE_POOL_ALLOCATOR) {
54 std::forward<Args>(args)...);
55 } else {
56 return make_intrusive_shared<T>(std::forward<Args>(args)...);
57 }
58}
59
60template <typename Scalar, ExpressionType T>
61struct BinaryMinusExpression;
62
63template <typename Scalar, ExpressionType T>
64struct BinaryPlusExpression;
65
66template <typename Scalar>
67struct ConstantExpression;
68
69template <typename Scalar, ExpressionType T>
70struct DivExpression;
71
72template <typename Scalar, ExpressionType T>
73struct MultExpression;
74
75template <typename Scalar, ExpressionType T>
76struct UnaryMinusExpression;
77
78/// Creates an intrusive shared pointer to a constant expression.
79///
80/// @tparam Scalar Scalar type.
81/// @param value The expression value.
82template <typename Scalar>
84
85/// An autodiff expression node.
86///
87/// @tparam Scalar Scalar type.
88template <typename Scalar_>
89struct Expression {
90 /// Scalar type alias.
91 using Scalar = Scalar_;
92
93 /// The value of the expression node.
95
96 /// The adjoint of the expression node, used during autodiff.
98
99 /// Counts incoming edges for this node.
100 uint32_t incoming_edges = 0;
101
102 /// This expression's column in a Jacobian, or -1 otherwise.
103 int32_t col = -1;
104
105 /// The adjoint of the expression node, used during gradient expression tree
106 /// generation.
108
109 /// Reference count for intrusive shared pointer.
110 uint32_t ref_count = 0;
111
112 /// Expression arguments.
113 std::array<ExpressionPtr<Scalar>, 2> args{nullptr, nullptr};
114
115 /// Constructs a constant expression with a value of zero.
116 constexpr Expression() = default;
117
118 /// Constructs a nullary expression (an operator with no arguments).
119 ///
120 /// @param value The expression value.
121 explicit constexpr Expression(Scalar value) : val{value} {}
122
123 /// Constructs an unary expression (an operator with one argument).
124 ///
125 /// @param lhs Unary operator's operand.
126 explicit constexpr Expression(ExpressionPtr<Scalar> lhs)
127 : args{std::move(lhs), nullptr} {}
128
129 /// Constructs a binary expression (an operator with two arguments).
130 ///
131 /// @param lhs Binary operator's left operand.
132 /// @param rhs Binary operator's right operand.
134 : args{std::move(lhs), std::move(rhs)} {}
135
136 virtual ~Expression() = default;
137
138 /// Returns true if the expression is the given constant.
139 ///
140 /// @param constant The constant.
141 /// @return True if the expression is the given constant.
142 constexpr bool is_constant(Scalar constant) const {
143 return type() == ExpressionType::CONSTANT && val == constant;
144 }
145
146 /// Expression-Expression multiplication operator.
147 ///
148 /// @param lhs Operator left-hand side.
149 /// @param rhs Operator right-hand side.
151 const ExpressionPtr<Scalar>& rhs) {
152 using enum ExpressionType;
153
154 // Prune expression
155 if (lhs->is_constant(Scalar(0))) {
156 // Return zero
157 return lhs;
158 } else if (rhs->is_constant(Scalar(0))) {
159 // Return zero
160 return rhs;
161 } else if (lhs->is_constant(Scalar(1))) {
162 return rhs;
163 } else if (rhs->is_constant(Scalar(1))) {
164 return lhs;
165 }
166
167 // Evaluate constant
168 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
169 return constant_ptr(lhs->val * rhs->val);
170 }
171
172 // Evaluate expression type
173 if (lhs->type() == CONSTANT) {
174 if (rhs->type() == LINEAR) {
176 } else if (rhs->type() == QUADRATIC) {
178 } else {
180 }
181 } else if (rhs->type() == CONSTANT) {
182 if (lhs->type() == LINEAR) {
184 } else if (lhs->type() == QUADRATIC) {
186 } else {
188 }
189 } else if (lhs->type() == LINEAR && rhs->type() == LINEAR) {
191 } else {
193 }
194 }
195
196 /// Expression-Expression division operator.
197 ///
198 /// @param lhs Operator left-hand side.
199 /// @param rhs Operator right-hand side.
201 const ExpressionPtr<Scalar>& rhs) {
202 using enum ExpressionType;
203
204 // Prune expression
205 if (lhs->is_constant(Scalar(0))) {
206 // Return zero
207 return lhs;
208 } else if (rhs->is_constant(Scalar(1))) {
209 return lhs;
210 }
211
212 // Evaluate constant
213 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
214 return constant_ptr(lhs->val / rhs->val);
215 }
216
217 // Evaluate expression type
218 if (rhs->type() == CONSTANT) {
219 if (lhs->type() == LINEAR) {
221 } else if (lhs->type() == QUADRATIC) {
223 } else {
225 }
226 } else {
228 }
229 }
230
231 /// Expression-Expression addition operator.
232 ///
233 /// @param lhs Operator left-hand side.
234 /// @param rhs Operator right-hand side.
236 const ExpressionPtr<Scalar>& rhs) {
237 using enum ExpressionType;
238
239 // Prune expression
240 if (lhs == nullptr || lhs->is_constant(Scalar(0))) {
241 return rhs;
242 } else if (rhs == nullptr || rhs->is_constant(Scalar(0))) {
243 return lhs;
244 }
245
246 // Evaluate constant
247 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
248 return constant_ptr(lhs->val + rhs->val);
249 }
250
251 auto type = std::max(lhs->type(), rhs->type());
252 if (type == LINEAR) {
254 rhs);
255 } else if (type == QUADRATIC) {
257 rhs);
258 } else {
260 rhs);
261 }
262 }
263
264 /// Expression-Expression compound addition operator.
265 ///
266 /// @param lhs Operator left-hand side.
267 /// @param rhs Operator right-hand side.
269 const ExpressionPtr<Scalar>& rhs) {
270 return lhs = lhs + rhs;
271 }
272
273 /// Expression-Expression subtraction operator.
274 ///
275 /// @param lhs Operator left-hand side.
276 /// @param rhs Operator right-hand side.
278 const ExpressionPtr<Scalar>& rhs) {
279 using enum ExpressionType;
280
281 // Prune expression
282 if (lhs->is_constant(Scalar(0))) {
283 if (rhs->is_constant(Scalar(0))) {
284 // Return zero
285 return rhs;
286 } else {
287 return -rhs;
288 }
289 } else if (rhs->is_constant(Scalar(0))) {
290 return lhs;
291 }
292
293 // Evaluate constant
294 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
295 return constant_ptr(lhs->val - rhs->val);
296 }
297
298 auto type = std::max(lhs->type(), rhs->type());
299 if (type == LINEAR) {
301 rhs);
302 } else if (type == QUADRATIC) {
304 rhs);
305 } else {
307 rhs);
308 }
309 }
310
311 /// Unary minus operator.
312 ///
313 /// @param lhs Operand of unary minus.
315 using enum ExpressionType;
316
317 // Prune expression
318 if (lhs->is_constant(Scalar(0))) {
319 // Return zero
320 return lhs;
321 }
322
323 // Evaluate constant
324 if (lhs->type() == CONSTANT) {
325 return constant_ptr(-lhs->val);
326 }
327
328 if (lhs->type() == LINEAR) {
330 } else if (lhs->type() == QUADRATIC) {
332 } else {
334 }
335 }
336
337 /// Unary plus operator.
338 ///
339 /// @param lhs Operand of unary plus.
341 return lhs;
342 }
343
344 /// Either nullary operator with no arguments, unary operator with one
345 /// argument, or binary operator with two arguments. This operator is used to
346 /// update the node's value.
347 ///
348 /// @param lhs Left argument to binary operator.
349 /// @param rhs Right argument to binary operator.
350 /// @return The node's value.
351 virtual Scalar value([[maybe_unused]] Scalar lhs,
352 [[maybe_unused]] Scalar rhs) const = 0;
353
354 /// Returns the type of this expression (constant, linear, quadratic, or
355 /// nonlinear).
356 ///
357 /// @return The type of this expression.
358 virtual ExpressionType type() const = 0;
359
360 /// Returns the name of this expression.
361 ///
362 /// @return The name of this expression.
363 virtual std::string_view name() const = 0;
364
365 /// Returns ∂/∂l as a Scalar.
366 ///
367 /// @param lhs Left argument to binary operator.
368 /// @param rhs Right argument to binary operator.
369 /// @param parent_adjoint Adjoint of parent expression.
370 /// @return ∂/∂l as a Scalar.
371 virtual Scalar grad_l([[maybe_unused]] Scalar lhs,
372 [[maybe_unused]] Scalar rhs,
373 [[maybe_unused]] Scalar parent_adjoint) const {
374 return Scalar(0);
375 }
376
377 /// Returns ∂/∂r as a Scalar.
378 ///
379 /// @param lhs Left argument to binary operator.
380 /// @param rhs Right argument to binary operator.
381 /// @param parent_adjoint Adjoint of parent expression.
382 /// @return ∂/∂r as a Scalar.
383 virtual Scalar grad_r([[maybe_unused]] Scalar lhs,
384 [[maybe_unused]] Scalar rhs,
385 [[maybe_unused]] Scalar parent_adjoint) const {
386 return Scalar(0);
387 }
388
389 /// Returns ∂/∂l as an Expression.
390 ///
391 /// @param lhs Left argument to binary operator.
392 /// @param rhs Right argument to binary operator.
393 /// @param parent_adjoint Adjoint of parent expression.
394 /// @return ∂/∂l as an Expression.
396 [[maybe_unused]] const ExpressionPtr<Scalar>& lhs,
397 [[maybe_unused]] const ExpressionPtr<Scalar>& rhs,
398 [[maybe_unused]] const ExpressionPtr<Scalar>& parent_adjoint) const {
399 return constant_ptr(Scalar(0));
400 }
401
402 /// Returns ∂/∂r as an Expression.
403 ///
404 /// @param lhs Left argument to binary operator.
405 /// @param rhs Right argument to binary operator.
406 /// @param parent_adjoint Adjoint of parent expression.
407 /// @return ∂/∂r as an Expression.
409 [[maybe_unused]] const ExpressionPtr<Scalar>& lhs,
410 [[maybe_unused]] const ExpressionPtr<Scalar>& rhs,
411 [[maybe_unused]] const ExpressionPtr<Scalar>& parent_adjoint) const {
412 return constant_ptr(Scalar(0));
413 }
414};
415
416template <typename Scalar>
420
421template <typename Scalar>
423template <typename Scalar>
425template <typename Scalar>
427template <typename Scalar>
429template <typename Scalar>
431
432/// Derived expression type for binary minus operator.
433///
434/// @tparam Scalar Scalar type.
435/// @tparam T Expression type.
436template <typename Scalar, ExpressionType T>
437struct BinaryMinusExpression final : Expression<Scalar> {
438 /// Constructs a binary expression (an operator with two arguments).
439 ///
440 /// @param lhs Binary operator's left operand.
441 /// @param rhs Binary operator's right operand.
444 : Expression<Scalar>{std::move(lhs), std::move(rhs)} {}
445
446 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs - rhs; }
447
448 ExpressionType type() const override { return T; }
449
450 std::string_view name() const override { return "binary minus"; }
451
452 Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override {
453 return parent_adjoint;
454 }
455
456 Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override {
457 return -parent_adjoint;
458 }
459
462 const ExpressionPtr<Scalar>& parent_adjoint) const override {
463 return parent_adjoint;
464 }
465
468 const ExpressionPtr<Scalar>& parent_adjoint) const override {
469 return -parent_adjoint;
470 }
471};
472
473/// Derived expression type for binary plus operator.
474///
475/// @tparam Scalar Scalar type.
476/// @tparam T Expression type.
477template <typename Scalar, ExpressionType T>
478struct BinaryPlusExpression final : Expression<Scalar> {
479 /// Constructs a binary expression (an operator with two arguments).
480 ///
481 /// @param lhs Binary operator's left operand.
482 /// @param rhs Binary operator's right operand.
485 : Expression<Scalar>{std::move(lhs), std::move(rhs)} {}
486
487 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs + rhs; }
488
489 ExpressionType type() const override { return T; }
490
491 std::string_view name() const override { return "binary plus"; }
492
493 Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override {
494 return parent_adjoint;
495 }
496
497 Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override {
498 return parent_adjoint;
499 }
500
503 const ExpressionPtr<Scalar>& parent_adjoint) const override {
504 return parent_adjoint;
505 }
506
509 const ExpressionPtr<Scalar>& parent_adjoint) const override {
510 return parent_adjoint;
511 }
512};
513
514/// Derived expression type for cbrt().
515///
516/// @tparam Scalar Scalar type.
517template <typename Scalar>
518struct CbrtExpression final : Expression<Scalar> {
519 /// Constructs an unary expression (an operator with one argument).
520 ///
521 /// @param lhs Unary operator's operand.
522 explicit constexpr CbrtExpression(ExpressionPtr<Scalar> lhs)
523 : Expression<Scalar>{std::move(lhs)} {}
524
525 Scalar value(Scalar x, Scalar) const override {
526 using std::cbrt;
527 return cbrt(x);
528 }
529
530 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
531
532 std::string_view name() const override { return "cbrt"; }
533
534 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
535 using std::cbrt;
536
537 Scalar c = cbrt(x);
538 return parent_adjoint / (Scalar(3) * c * c);
539 }
540
543 const ExpressionPtr<Scalar>& parent_adjoint) const override {
544 auto c = cbrt(x);
545 return parent_adjoint / (constant_ptr(Scalar(3)) * c * c);
546 }
547};
548
549/// cbrt() for Expressions.
550///
551/// @tparam Scalar Scalar type.
552/// @param x The argument.
553template <typename Scalar>
555 using enum ExpressionType;
556 using std::cbrt;
557
558 // Evaluate constant
559 if (x->type() == CONSTANT) {
560 if (x->val == Scalar(0)) {
561 // Return zero
562 return x;
563 } else if (x->val == Scalar(-1) || x->val == Scalar(1)) {
564 return x;
565 } else {
566 return constant_ptr(cbrt(x->val));
567 }
568 }
569
571}
572
573/// Derived expression type for constant.
574///
575/// @tparam Scalar Scalar type.
576template <typename Scalar>
577struct ConstantExpression final : Expression<Scalar> {
578 /// Constructs a nullary expression (an operator with no arguments).
579 ///
580 /// @param value The expression value.
581 explicit constexpr ConstantExpression(Scalar value)
582 : Expression<Scalar>{value} {}
583
584 Scalar value(Scalar, Scalar) const override { return this->val; }
585
586 ExpressionType type() const override { return ExpressionType::CONSTANT; }
587
588 std::string_view name() const override { return "constant"; }
589};
590
591/// Derived expression type for decision variable.
592///
593/// @tparam Scalar Scalar type.
594template <typename Scalar>
596 /// Constructs a decision variable expression with a value of zero.
597 constexpr DecisionVariableExpression() = default;
598
599 /// Constructs a nullary expression (an operator with no arguments).
600 ///
601 /// @param value The expression value.
604
605 Scalar value(Scalar, Scalar) const override { return this->val; }
606
607 ExpressionType type() const override { return ExpressionType::LINEAR; }
608
609 std::string_view name() const override { return "decision variable"; }
610};
611
612/// Derived expression type for binary division operator.
613///
614/// @tparam Scalar Scalar type.
615/// @tparam T Expression type.
616template <typename Scalar, ExpressionType T>
617struct DivExpression final : Expression<Scalar> {
618 /// Constructs a binary expression (an operator with two arguments).
619 ///
620 /// @param lhs Binary operator's left operand.
621 /// @param rhs Binary operator's right operand.
623 : Expression<Scalar>{std::move(lhs), std::move(rhs)} {}
624
625 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs / rhs; }
626
627 ExpressionType type() const override { return T; }
628
629 std::string_view name() const override { return "division"; }
630
631 Scalar grad_l(Scalar, Scalar rhs, Scalar parent_adjoint) const override {
632 return parent_adjoint / rhs;
633 };
634
635 Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override {
636 return parent_adjoint * -lhs / (rhs * rhs);
637 }
638
641 const ExpressionPtr<Scalar>& parent_adjoint) const override {
642 return parent_adjoint / rhs;
643 }
644
646 const ExpressionPtr<Scalar>& lhs, const ExpressionPtr<Scalar>& rhs,
647 const ExpressionPtr<Scalar>& parent_adjoint) const override {
648 return parent_adjoint * -lhs / (rhs * rhs);
649 }
650};
651
652/// Derived expression type for binary multiplication operator.
653///
654/// @tparam Scalar Scalar type.
655/// @tparam T Expression type.
656template <typename Scalar, ExpressionType T>
657struct MultExpression final : Expression<Scalar> {
658 /// Constructs a binary expression (an operator with two arguments).
659 ///
660 /// @param lhs Binary operator's left operand.
661 /// @param rhs Binary operator's right operand.
663 : Expression<Scalar>{std::move(lhs), std::move(rhs)} {}
664
665 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs * rhs; }
666
667 ExpressionType type() const override { return T; }
668
669 std::string_view name() const override { return "multiplication"; }
670
671 Scalar grad_l([[maybe_unused]] Scalar lhs, Scalar rhs,
672 Scalar parent_adjoint) const override {
673 return parent_adjoint * rhs;
674 }
675
676 Scalar grad_r(Scalar lhs, [[maybe_unused]] Scalar rhs,
677 Scalar parent_adjoint) const override {
678 return parent_adjoint * lhs;
679 }
680
682 [[maybe_unused]] const ExpressionPtr<Scalar>& lhs,
683 const ExpressionPtr<Scalar>& rhs,
684 const ExpressionPtr<Scalar>& parent_adjoint) const override {
685 return parent_adjoint * rhs;
686 }
687
689 const ExpressionPtr<Scalar>& lhs,
690 [[maybe_unused]] const ExpressionPtr<Scalar>& rhs,
691 const ExpressionPtr<Scalar>& parent_adjoint) const override {
692 return parent_adjoint * lhs;
693 }
694};
695
696/// Derived expression type for unary minus operator.
697///
698/// @tparam Scalar Scalar type.
699/// @tparam T Expression type.
700template <typename Scalar, ExpressionType T>
701struct UnaryMinusExpression final : Expression<Scalar> {
702 /// Constructs an unary expression (an operator with one argument).
703 ///
704 /// @param lhs Unary operator's operand.
706 : Expression<Scalar>{std::move(lhs)} {}
707
708 Scalar value(Scalar lhs, Scalar) const override { return -lhs; }
709
710 ExpressionType type() const override { return T; }
711
712 std::string_view name() const override { return "unary minus"; }
713
714 Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override {
715 return -parent_adjoint;
716 }
717
720 const ExpressionPtr<Scalar>& parent_adjoint) const override {
721 return -parent_adjoint;
722 }
723};
724
725/// Refcount increment for intrusive shared pointer.
726///
727/// @tparam Scalar Scalar type.
728/// @param expr The shared pointer's managed object.
729template <typename Scalar>
730constexpr void inc_ref_count(Expression<Scalar>* expr) {
731 ++expr->ref_count;
732}
733
734/// Refcount decrement for intrusive shared pointer.
735///
736/// @tparam Scalar Scalar type.
737/// @param expr The shared pointer's managed object.
738template <typename Scalar>
740 // If a deeply nested tree is being deallocated all at once, calling the
741 // Expression destructor when expr's refcount reaches zero can cause a stack
742 // overflow. Instead, we iterate over its children to decrement their
743 // refcounts and deallocate them.
745 stack.emplace_back(expr);
746
747 while (!stack.empty()) {
748 auto elem = stack.back();
749 stack.pop_back();
750
751 // Decrement the current node's refcount. If it reaches zero, deallocate the
752 // node and enqueue its children so their refcounts are decremented too.
753 if (--elem->ref_count == 0) {
754 if (elem->adjoint_expr != nullptr) {
755 stack.emplace_back(elem->adjoint_expr.get());
756 }
757 for (auto& arg : elem->args) {
758 if (arg != nullptr) {
759 stack.emplace_back(arg.get());
760 }
761 }
762
763 // Not calling the destructor here is safe because it only decrements
764 // refcounts, which was already done above.
765 if constexpr (USE_POOL_ALLOCATOR) {
767 std::allocator_traits<decltype(alloc)>::deallocate(
768 alloc, elem, sizeof(Expression<Scalar>));
769 }
770 }
771 }
772}
773
774/// Derived expression type for abs().
775///
776/// @tparam Scalar Scalar type.
777template <typename Scalar>
778struct AbsExpression final : Expression<Scalar> {
779 /// Constructs an unary expression (an operator with one argument).
780 ///
781 /// @param lhs Unary operator's operand.
782 explicit constexpr AbsExpression(ExpressionPtr<Scalar> lhs)
783 : Expression<Scalar>{std::move(lhs)} {}
784
785 Scalar value(Scalar x, Scalar) const override {
786 using std::abs;
787 return abs(x);
788 }
789
790 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
791
792 std::string_view name() const override { return "abs"; }
793
794 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
795 if (x < Scalar(0)) {
796 return -parent_adjoint;
797 } else if (x > Scalar(0)) {
798 return parent_adjoint;
799 } else {
800 return Scalar(0);
801 }
802 }
803
806 const ExpressionPtr<Scalar>& parent_adjoint) const override {
807 if (x->val < Scalar(0)) {
808 return -parent_adjoint;
809 } else if (x->val > Scalar(0)) {
810 return parent_adjoint;
811 } else {
812 return constant_ptr(Scalar(0));
813 }
814 }
815};
816
817/// abs() for Expressions.
818///
819/// @tparam Scalar Scalar type.
820/// @param x The argument.
821template <typename Scalar>
823 using enum ExpressionType;
824 using std::abs;
825
826 // Prune expression
827 if (x->is_constant(Scalar(0))) {
828 // Return zero
829 return x;
830 }
831
832 // Evaluate constant
833 if (x->type() == CONSTANT) {
834 return constant_ptr(abs(x->val));
835 }
836
838}
839
840/// Derived expression type for acos().
841///
842/// @tparam Scalar Scalar type.
843template <typename Scalar>
844struct AcosExpression final : Expression<Scalar> {
845 /// Constructs an unary expression (an operator with one argument).
846 ///
847 /// @param lhs Unary operator's operand.
848 explicit constexpr AcosExpression(ExpressionPtr<Scalar> lhs)
849 : Expression<Scalar>{std::move(lhs)} {}
850
851 Scalar value(Scalar x, Scalar) const override {
852 using std::acos;
853 return acos(x);
854 }
855
856 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
857
858 std::string_view name() const override { return "acos"; }
859
860 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
861 using std::sqrt;
862 return -parent_adjoint / sqrt(Scalar(1) - x * x);
863 }
864
867 const ExpressionPtr<Scalar>& parent_adjoint) const override {
868 return -parent_adjoint / sqrt(constant_ptr(Scalar(1)) - x * x);
869 }
870};
871
872/// acos() for Expressions.
873///
874/// @tparam Scalar Scalar type.
875/// @param x The argument.
876template <typename Scalar>
878 using enum ExpressionType;
879 using std::acos;
880
881 // Prune expression
882 if (x->is_constant(Scalar(0))) {
883 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
884 }
885
886 // Evaluate constant
887 if (x->type() == CONSTANT) {
888 return constant_ptr(acos(x->val));
889 }
890
892}
893
894/// Derived expression type for asin().
895///
896/// @tparam Scalar Scalar type.
897template <typename Scalar>
898struct AsinExpression final : Expression<Scalar> {
899 /// Constructs an unary expression (an operator with one argument).
900 ///
901 /// @param lhs Unary operator's operand.
902 explicit constexpr AsinExpression(ExpressionPtr<Scalar> lhs)
903 : Expression<Scalar>{std::move(lhs)} {}
904
905 Scalar value(Scalar x, Scalar) const override {
906 using std::asin;
907 return asin(x);
908 }
909
910 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
911
912 std::string_view name() const override { return "asin"; }
913
914 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
915 using std::sqrt;
916 return parent_adjoint / sqrt(Scalar(1) - x * x);
917 }
918
921 const ExpressionPtr<Scalar>& parent_adjoint) const override {
922 return parent_adjoint / sqrt(constant_ptr(Scalar(1)) - x * x);
923 }
924};
925
926/// asin() for Expressions.
927///
928/// @tparam Scalar Scalar type.
929/// @param x The argument.
930template <typename Scalar>
932 using enum ExpressionType;
933 using std::asin;
934
935 // Prune expression
936 if (x->is_constant(Scalar(0))) {
937 // Return zero
938 return x;
939 }
940
941 // Evaluate constant
942 if (x->type() == CONSTANT) {
943 return constant_ptr(asin(x->val));
944 }
945
947}
948
949/// Derived expression type for atan().
950///
951/// @tparam Scalar Scalar type.
952template <typename Scalar>
953struct AtanExpression final : Expression<Scalar> {
954 /// Constructs an unary expression (an operator with one argument).
955 ///
956 /// @param lhs Unary operator's operand.
957 explicit constexpr AtanExpression(ExpressionPtr<Scalar> lhs)
958 : Expression<Scalar>{std::move(lhs)} {}
959
960 Scalar value(Scalar x, Scalar) const override {
961 using std::atan;
962 return atan(x);
963 }
964
965 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
966
967 std::string_view name() const override { return "atan"; }
968
969 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
970 return parent_adjoint / (Scalar(1) + x * x);
971 }
972
975 const ExpressionPtr<Scalar>& parent_adjoint) const override {
976 return parent_adjoint / (constant_ptr(Scalar(1)) + x * x);
977 }
978};
979
980/// atan() for Expressions.
981///
982/// @tparam Scalar Scalar type.
983/// @param x The argument.
984template <typename Scalar>
986 using enum ExpressionType;
987 using std::atan;
988
989 // Prune expression
990 if (x->is_constant(Scalar(0))) {
991 // Return zero
992 return x;
993 }
994
995 // Evaluate constant
996 if (x->type() == CONSTANT) {
997 return constant_ptr(atan(x->val));
998 }
999
1001}
1002
1003/// Derived expression type for atan2().
1004///
1005/// @tparam Scalar Scalar type.
1006template <typename Scalar>
1007struct Atan2Expression final : Expression<Scalar> {
1008 /// Constructs a binary expression (an operator with two arguments).
1009 ///
1010 /// @param lhs Binary operator's left operand.
1011 /// @param rhs Binary operator's right operand.
1014 : Expression<Scalar>{std::move(lhs), std::move(rhs)} {}
1015
1016 Scalar value(Scalar y, Scalar x) const override {
1017 using std::atan2;
1018 return atan2(y, x);
1019 }
1020
1021 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1022
1023 std::string_view name() const override { return "atan2"; }
1024
1025 Scalar grad_l(Scalar y, Scalar x, Scalar parent_adjoint) const override {
1026 return parent_adjoint * x / (y * y + x * x);
1027 }
1028
1029 Scalar grad_r(Scalar y, Scalar x, Scalar parent_adjoint) const override {
1030 return parent_adjoint * -y / (y * y + x * x);
1031 }
1032
1035 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1036 return parent_adjoint * x / (y * y + x * x);
1037 }
1038
1041 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1042 return parent_adjoint * -y / (y * y + x * x);
1043 }
1044};
1045
1046/// atan2() for Expressions.
1047///
1048/// @tparam Scalar Scalar type.
1049/// @param y The y argument.
1050/// @param x The x argument.
1051template <typename Scalar>
1053 const ExpressionPtr<Scalar>& x) {
1054 using enum ExpressionType;
1055 using std::atan2;
1056
1057 // Prune expression
1058 if (y->is_constant(Scalar(0))) {
1059 // Return zero
1060 return y;
1061 } else if (x->is_constant(Scalar(0))) {
1062 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
1063 }
1064
1065 // Evaluate constant
1066 if (y->type() == CONSTANT && x->type() == CONSTANT) {
1067 return constant_ptr(atan2(y->val, x->val));
1068 }
1069
1071}
1072
1073/// Derived expression type for cos().
1074///
1075/// @tparam Scalar Scalar type.
1076template <typename Scalar>
1077struct CosExpression final : Expression<Scalar> {
1078 /// Constructs an unary expression (an operator with one argument).
1079 ///
1080 /// @param lhs Unary operator's operand.
1081 explicit constexpr CosExpression(ExpressionPtr<Scalar> lhs)
1082 : Expression<Scalar>{std::move(lhs)} {}
1083
1084 Scalar value(Scalar x, Scalar) const override {
1085 using std::cos;
1086 return cos(x);
1087 }
1088
1089 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1090
1091 std::string_view name() const override { return "cos"; }
1092
1093 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
1094 using std::sin;
1095 return parent_adjoint * -sin(x);
1096 }
1097
1100 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1101 return parent_adjoint * -sin(x);
1102 }
1103};
1104
1105/// cos() for Expressions.
1106///
1107/// @tparam Scalar Scalar type.
1108/// @param x The argument.
1109template <typename Scalar>
1111 using enum ExpressionType;
1112 using std::cos;
1113
1114 // Prune expression
1115 if (x->is_constant(Scalar(0))) {
1116 return constant_ptr(Scalar(1));
1117 }
1118
1119 // Evaluate constant
1120 if (x->type() == CONSTANT) {
1121 return constant_ptr(cos(x->val));
1122 }
1123
1125}
1126
1127/// Derived expression type for cosh().
1128///
1129/// @tparam Scalar Scalar type.
1130template <typename Scalar>
1131struct CoshExpression final : Expression<Scalar> {
1132 /// Constructs an unary expression (an operator with one argument).
1133 ///
1134 /// @param lhs Unary operator's operand.
1136 : Expression<Scalar>{std::move(lhs)} {}
1137
1138 Scalar value(Scalar x, Scalar) const override {
1139 using std::cosh;
1140 return cosh(x);
1141 }
1142
1143 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1144
1145 std::string_view name() const override { return "cosh"; }
1146
1147 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
1148 using std::sinh;
1149 return parent_adjoint * sinh(x);
1150 }
1151
1154 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1155 return parent_adjoint * sinh(x);
1156 }
1157};
1158
1159/// cosh() for Expressions.
1160///
1161/// @tparam Scalar Scalar type.
1162/// @param x The argument.
1163template <typename Scalar>
1165 using enum ExpressionType;
1166 using std::cosh;
1167
1168 // Prune expression
1169 if (x->is_constant(Scalar(0))) {
1170 return constant_ptr(Scalar(1));
1171 }
1172
1173 // Evaluate constant
1174 if (x->type() == CONSTANT) {
1175 return constant_ptr(cosh(x->val));
1176 }
1177
1179}
1180
1181/// Derived expression type for erf().
1182///
1183/// @tparam Scalar Scalar type.
1184template <typename Scalar>
1185struct ErfExpression final : Expression<Scalar> {
1186 /// Constructs an unary expression (an operator with one argument).
1187 ///
1188 /// @param lhs Unary operator's operand.
1189 explicit constexpr ErfExpression(ExpressionPtr<Scalar> lhs)
1190 : Expression<Scalar>{std::move(lhs)} {}
1191
1192 Scalar value(Scalar x, Scalar) const override {
1193 using std::erf;
1194 return erf(x);
1195 }
1196
1197 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1198
1199 std::string_view name() const override { return "erf"; }
1200
1201 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
1202 using std::exp;
1203 return parent_adjoint * Scalar(2.0 * std::numbers::inv_sqrtpi) *
1204 exp(-x * x);
1205 }
1206
1209 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1210 return parent_adjoint *
1211 constant_ptr(Scalar(2.0 * std::numbers::inv_sqrtpi)) * exp(-x * x);
1212 }
1213};
1214
1215/// erf() for Expressions.
1216///
1217/// @tparam Scalar Scalar type.
1218/// @param x The argument.
1219template <typename Scalar>
1221 using enum ExpressionType;
1222 using std::erf;
1223
1224 // Prune expression
1225 if (x->is_constant(Scalar(0))) {
1226 // Return zero
1227 return x;
1228 }
1229
1230 // Evaluate constant
1231 if (x->type() == CONSTANT) {
1232 return constant_ptr(erf(x->val));
1233 }
1234
1236}
1237
1238/// Derived expression type for exp().
1239///
1240/// @tparam Scalar Scalar type.
1241template <typename Scalar>
1242struct ExpExpression final : Expression<Scalar> {
1243 /// Constructs an unary expression (an operator with one argument).
1244 ///
1245 /// @param lhs Unary operator's operand.
1246 explicit constexpr ExpExpression(ExpressionPtr<Scalar> lhs)
1247 : Expression<Scalar>{std::move(lhs)} {}
1248
1249 Scalar value(Scalar x, Scalar) const override {
1250 using std::exp;
1251 return exp(x);
1252 }
1253
1254 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1255
1256 std::string_view name() const override { return "exp"; }
1257
1258 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
1259 using std::exp;
1260 return parent_adjoint * exp(x);
1261 }
1262
1265 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1266 return parent_adjoint * exp(x);
1267 }
1268};
1269
1270/// exp() for Expressions.
1271///
1272/// @tparam Scalar Scalar type.
1273/// @param x The argument.
1274template <typename Scalar>
1276 using enum ExpressionType;
1277 using std::exp;
1278
1279 // Prune expression
1280 if (x->is_constant(Scalar(0))) {
1281 return constant_ptr(Scalar(1));
1282 }
1283
1284 // Evaluate constant
1285 if (x->type() == CONSTANT) {
1286 return constant_ptr(exp(x->val));
1287 }
1288
1290}
1291
1292template <typename Scalar>
1294 const ExpressionPtr<Scalar>& y);
1295
1296/// Derived expression type for hypot().
1297///
1298/// @tparam Scalar Scalar type.
1299template <typename Scalar>
1300struct HypotExpression final : Expression<Scalar> {
1301 /// Constructs a binary expression (an operator with two arguments).
1302 ///
1303 /// @param lhs Binary operator's left operand.
1304 /// @param rhs Binary operator's right operand.
1307 : Expression<Scalar>{std::move(lhs), std::move(rhs)} {}
1308
1309 Scalar value(Scalar x, Scalar y) const override {
1310 using std::hypot;
1311 return hypot(x, y);
1312 }
1313
1314 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1315
1316 std::string_view name() const override { return "hypot"; }
1317
1318 Scalar grad_l(Scalar x, Scalar y, Scalar parent_adjoint) const override {
1319 using std::hypot;
1320 return parent_adjoint * x / hypot(x, y);
1321 }
1322
1323 Scalar grad_r(Scalar x, Scalar y, Scalar parent_adjoint) const override {
1324 using std::hypot;
1325 return parent_adjoint * y / hypot(x, y);
1326 }
1327
1330 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1331 return parent_adjoint * x / hypot(x, y);
1332 }
1333
1336 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1337 return parent_adjoint * y / hypot(x, y);
1338 }
1339};
1340
1341/// hypot() for Expressions.
1342///
1343/// @tparam Scalar Scalar type.
1344/// @param x The x argument.
1345/// @param y The y argument.
1346template <typename Scalar>
1348 const ExpressionPtr<Scalar>& y) {
1349 using enum ExpressionType;
1350 using std::hypot;
1351
1352 // Prune expression
1353 if (x->is_constant(Scalar(0))) {
1354 return y;
1355 } else if (y->is_constant(Scalar(0))) {
1356 return x;
1357 }
1358
1359 // Evaluate constant
1360 if (x->type() == CONSTANT && y->type() == CONSTANT) {
1361 return constant_ptr(hypot(x->val, y->val));
1362 }
1363
1365}
1366
1367/// Derived expression type for log().
1368///
1369/// @tparam Scalar Scalar type.
1370template <typename Scalar>
1371struct LogExpression final : Expression<Scalar> {
1372 /// Constructs an unary expression (an operator with one argument).
1373 ///
1374 /// @param lhs Unary operator's operand.
1375 explicit constexpr LogExpression(ExpressionPtr<Scalar> lhs)
1376 : Expression<Scalar>{std::move(lhs)} {}
1377
1378 Scalar value(Scalar x, Scalar) const override {
1379 using std::log;
1380 return log(x);
1381 }
1382
1383 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1384
1385 std::string_view name() const override { return "log"; }
1386
1387 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
1388 return parent_adjoint / x;
1389 }
1390
1393 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1394 return parent_adjoint / x;
1395 }
1396};
1397
1398/// log() for Expressions.
1399///
1400/// @tparam Scalar Scalar type.
1401/// @param x The argument.
1402template <typename Scalar>
1404 using enum ExpressionType;
1405 using std::log;
1406
1407 // Prune expression
1408 if (x->is_constant(Scalar(0))) {
1409 // Return zero
1410 return x;
1411 }
1412
1413 // Evaluate constant
1414 if (x->type() == CONSTANT) {
1415 return constant_ptr(log(x->val));
1416 }
1417
1419}
1420
1421/// Derived expression type for log10().
1422///
1423/// @tparam Scalar Scalar type.
1424template <typename Scalar>
1425struct Log10Expression final : Expression<Scalar> {
1426 /// Constructs an unary expression (an operator with one argument).
1427 ///
1428 /// @param lhs Unary operator's operand.
1430 : Expression<Scalar>{std::move(lhs)} {}
1431
1432 Scalar value(Scalar x, Scalar) const override {
1433 using std::log10;
1434 return log10(x);
1435 }
1436
1437 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1438
1439 std::string_view name() const override { return "log10"; }
1440
1441 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
1442 return parent_adjoint / (Scalar(std::numbers::ln10) * x);
1443 }
1444
1447 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1448 return parent_adjoint / (constant_ptr(Scalar(std::numbers::ln10)) * x);
1449 }
1450};
1451
1452/// log10() for Expressions.
1453///
1454/// @tparam Scalar Scalar type.
1455/// @param x The argument.
1456template <typename Scalar>
1458 using enum ExpressionType;
1459 using std::log10;
1460
1461 // Prune expression
1462 if (x->is_constant(Scalar(0))) {
1463 // Return zero
1464 return x;
1465 }
1466
1467 // Evaluate constant
1468 if (x->type() == CONSTANT) {
1469 return constant_ptr(log10(x->val));
1470 }
1471
1473}
1474
1475template <typename Scalar>
1477 const ExpressionPtr<Scalar>& power);
1478
1479/// Derived expression type for pow().
1480///
1481/// @tparam Scalar Scalar type.
1482/// @tparam T Expression type.
1483template <typename Scalar, ExpressionType T>
1484struct PowExpression final : Expression<Scalar> {
1485 /// Constructs a binary expression (an operator with two arguments).
1486 ///
1487 /// @param lhs Binary operator's left operand.
1488 /// @param rhs Binary operator's right operand.
1490 : Expression<Scalar>{std::move(lhs), std::move(rhs)} {}
1491
1492 Scalar value(Scalar base, Scalar power) const override {
1493 using std::pow;
1494 return pow(base, power);
1495 }
1496
1497 ExpressionType type() const override { return T; }
1498
1499 std::string_view name() const override { return "pow"; }
1500
1502 Scalar parent_adjoint) const override {
1503 using std::pow;
1504 return parent_adjoint * pow(base, power - Scalar(1)) * power;
1505 }
1506
1508 Scalar parent_adjoint) const override {
1509 using std::log;
1510 using std::pow;
1511
1512 // Since x log(x) -> 0 as x -> 0
1513 if (base == Scalar(0)) {
1514 return Scalar(0);
1515 } else {
1516 return parent_adjoint * pow(base, power) * log(base);
1517 }
1518 }
1519
1521 const ExpressionPtr<Scalar>& base, const ExpressionPtr<Scalar>& power,
1522 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1523 return parent_adjoint * pow(base, power - constant_ptr(Scalar(1))) * power;
1524 }
1525
1527 const ExpressionPtr<Scalar>& base, const ExpressionPtr<Scalar>& power,
1528 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1529 // Since x log(x) -> 0 as x -> 0
1530 if (base->val == Scalar(0)) {
1531 // Return zero
1532 return base;
1533 } else {
1534 return parent_adjoint * pow(base, power) * log(base);
1535 }
1536 }
1537};
1538
1539/// pow() for Expressions.
1540///
1541/// @tparam Scalar Scalar type.
1542/// @param base The base.
1543/// @param power The power.
1544template <typename Scalar>
1546 const ExpressionPtr<Scalar>& power) {
1547 using enum ExpressionType;
1548 using std::pow;
1549
1550 // Prune expression
1551 if (base->is_constant(Scalar(0))) {
1552 // Return zero
1553 return base;
1554 } else if (base->is_constant(Scalar(1))) {
1555 // Return one
1556 return base;
1557 }
1558 if (power->is_constant(Scalar(0))) {
1559 return constant_ptr(Scalar(1));
1560 } else if (power->is_constant(Scalar(1))) {
1561 return base;
1562 }
1563
1564 // Evaluate constant
1565 if (base->type() == CONSTANT && power->type() == CONSTANT) {
1566 return constant_ptr(pow(base->val, power->val));
1567 }
1568
1569 if (power->is_constant(Scalar(2))) {
1570 if (base->type() == LINEAR) {
1572 } else {
1574 }
1575 }
1576
1578}
1579
1580/// Derived expression type for sign().
1581///
1582/// @tparam Scalar Scalar type.
1583template <typename Scalar>
1584struct SignExpression final : Expression<Scalar> {
1585 /// Constructs an unary expression (an operator with one argument).
1586 ///
1587 /// @param lhs Unary operator's operand.
1589 : Expression<Scalar>{std::move(lhs)} {}
1590
1591 Scalar value(Scalar x, Scalar) const override {
1592 if (x < Scalar(0)) {
1593 return Scalar(-1);
1594 } else if (x == Scalar(0)) {
1595 return Scalar(0);
1596 } else {
1597 return Scalar(1);
1598 }
1599 }
1600
1601 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1602
1603 std::string_view name() const override { return "sign"; }
1604};
1605
1606/// sign() for Expressions.
1607///
1608/// @tparam Scalar Scalar type.
1609/// @param x The argument.
1610template <typename Scalar>
1612 using enum ExpressionType;
1613
1614 // Evaluate constant
1615 if (x->type() == CONSTANT) {
1616 if (x->val < Scalar(0)) {
1617 return constant_ptr(Scalar(-1));
1618 } else if (x->val == Scalar(0)) {
1619 // Return zero
1620 return x;
1621 } else {
1622 return constant_ptr(Scalar(1));
1623 }
1624 }
1625
1627}
1628
1629/// Derived expression type for sin().
1630///
1631/// @tparam Scalar Scalar type.
1632template <typename Scalar>
1633struct SinExpression final : Expression<Scalar> {
1634 /// Constructs an unary expression (an operator with one argument).
1635 ///
1636 /// @param lhs Unary operator's operand.
1637 explicit constexpr SinExpression(ExpressionPtr<Scalar> lhs)
1638 : Expression<Scalar>{std::move(lhs)} {}
1639
1640 Scalar value(Scalar x, Scalar) const override {
1641 using std::sin;
1642 return sin(x);
1643 }
1644
1645 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1646
1647 std::string_view name() const override { return "sin"; }
1648
1649 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
1650 using std::cos;
1651 return parent_adjoint * cos(x);
1652 }
1653
1656 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1657 return parent_adjoint * cos(x);
1658 }
1659};
1660
1661/// sin() for Expressions.
1662///
1663/// @tparam Scalar Scalar type.
1664/// @param x The argument.
1665template <typename Scalar>
1667 using enum ExpressionType;
1668 using std::sin;
1669
1670 // Prune expression
1671 if (x->is_constant(Scalar(0))) {
1672 // Return zero
1673 return x;
1674 }
1675
1676 // Evaluate constant
1677 if (x->type() == CONSTANT) {
1678 return constant_ptr(sin(x->val));
1679 }
1680
1682}
1683
1684/// Derived expression type for sinh().
1685///
1686/// @tparam Scalar Scalar type.
1687template <typename Scalar>
1688struct SinhExpression final : Expression<Scalar> {
1689 /// Constructs an unary expression (an operator with one argument).
1690 ///
1691 /// @param lhs Unary operator's operand.
1693 : Expression<Scalar>{std::move(lhs)} {}
1694
1695 Scalar value(Scalar x, Scalar) const override {
1696 using std::sinh;
1697 return sinh(x);
1698 }
1699
1700 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1701
1702 std::string_view name() const override { return "sinh"; }
1703
1704 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
1705 using std::cosh;
1706 return parent_adjoint * cosh(x);
1707 }
1708
1711 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1712 return parent_adjoint * cosh(x);
1713 }
1714};
1715
1716/// sinh() for Expressions.
1717///
1718/// @tparam Scalar Scalar type.
1719/// @param x The argument.
1720template <typename Scalar>
1722 using enum ExpressionType;
1723 using std::sinh;
1724
1725 // Prune expression
1726 if (x->is_constant(Scalar(0))) {
1727 // Return zero
1728 return x;
1729 }
1730
1731 // Evaluate constant
1732 if (x->type() == CONSTANT) {
1733 return constant_ptr(sinh(x->val));
1734 }
1735
1737}
1738
1739/// Derived expression type for sqrt().
1740///
1741/// @tparam Scalar Scalar type.
1742template <typename Scalar>
1743struct SqrtExpression final : Expression<Scalar> {
1744 /// Constructs an unary expression (an operator with one argument).
1745 ///
1746 /// @param lhs Unary operator's operand.
1748 : Expression<Scalar>{std::move(lhs)} {}
1749
1750 Scalar value(Scalar x, Scalar) const override {
1751 using std::sqrt;
1752 return sqrt(x);
1753 }
1754
1755 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1756
1757 std::string_view name() const override { return "sqrt"; }
1758
1759 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
1760 using std::sqrt;
1761 return parent_adjoint / (Scalar(2) * sqrt(x));
1762 }
1763
1766 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1767 return parent_adjoint / (constant_ptr(Scalar(2)) * sqrt(x));
1768 }
1769};
1770
1771/// sqrt() for Expressions.
1772///
1773/// @tparam Scalar Scalar type.
1774/// @param x The argument.
1775template <typename Scalar>
1777 using enum ExpressionType;
1778 using std::sqrt;
1779
1780 // Evaluate constant
1781 if (x->type() == CONSTANT) {
1782 if (x->val == Scalar(0)) {
1783 // Return zero
1784 return x;
1785 } else if (x->val == Scalar(1)) {
1786 return x;
1787 } else {
1788 return constant_ptr(sqrt(x->val));
1789 }
1790 }
1791
1793}
1794
1795/// Derived expression type for tan().
1796///
1797/// @tparam Scalar Scalar type.
1798template <typename Scalar>
1799struct TanExpression final : Expression<Scalar> {
1800 /// Constructs an unary expression (an operator with one argument).
1801 ///
1802 /// @param lhs Unary operator's operand.
1803 explicit constexpr TanExpression(ExpressionPtr<Scalar> lhs)
1804 : Expression<Scalar>{std::move(lhs)} {}
1805
1806 Scalar value(Scalar x, Scalar) const override {
1807 using std::tan;
1808 return tan(x);
1809 }
1810
1811 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1812
1813 std::string_view name() const override { return "tan"; }
1814
1815 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
1816 using std::cos;
1817
1818 auto c = cos(x);
1819 return parent_adjoint / (c * c);
1820 }
1821
1824 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1825 auto c = cos(x);
1826 return parent_adjoint / (c * c);
1827 }
1828};
1829
1830/// tan() for Expressions.
1831///
1832/// @tparam Scalar Scalar type.
1833/// @param x The argument.
1834template <typename Scalar>
1836 using enum ExpressionType;
1837 using std::tan;
1838
1839 // Prune expression
1840 if (x->is_constant(Scalar(0))) {
1841 // Return zero
1842 return x;
1843 }
1844
1845 // Evaluate constant
1846 if (x->type() == CONSTANT) {
1847 return constant_ptr(tan(x->val));
1848 }
1849
1851}
1852
1853/// Derived expression type for tanh().
1854///
1855/// @tparam Scalar Scalar type.
1856template <typename Scalar>
1857struct TanhExpression final : Expression<Scalar> {
1858 /// Constructs an unary expression (an operator with one argument).
1859 ///
1860 /// @param lhs Unary operator's operand.
1862 : Expression<Scalar>{std::move(lhs)} {}
1863
1864 Scalar value(Scalar x, Scalar) const override {
1865 using std::tanh;
1866 return tanh(x);
1867 }
1868
1869 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1870
1871 std::string_view name() const override { return "tanh"; }
1872
1873 Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
1874 using std::cosh;
1875
1876 auto c = cosh(x);
1877 return parent_adjoint / (c * c);
1878 }
1879
1882 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1883 auto c = cosh(x);
1884 return parent_adjoint / (c * c);
1885 }
1886};
1887
1888/// tanh() for Expressions.
1889///
1890/// @tparam Scalar Scalar type.
1891/// @param x The argument.
1892template <typename Scalar>
1894 using enum ExpressionType;
1895 using std::tanh;
1896
1897 // Prune expression
1898 if (x->is_constant(Scalar(0))) {
1899 // Return zero
1900 return x;
1901 }
1902
1903 // Evaluate constant
1904 if (x->type() == CONSTANT) {
1905 return constant_ptr(tanh(x->val));
1906 }
1907
1909}
1910
1911} // namespace slp::detail
auto arg(const Char *name, const T &arg) -> detail::named_arg< Char, T >
Returns a named argument to be used in a formatting function.
Definition base.h:2846
sign
Definition base.h:689
A custom intrusive shared pointer implementation without thread synchronization overhead.
Definition intrusive_shared_ptr.hpp:27
wpi::util::SmallVector< T > small_vector
Definition small_vector.hpp:10
Definition expression_graph.hpp:11
ExpressionPtr< Scalar > cosh(const ExpressionPtr< Scalar > &x)
cosh() for Expressions.
Definition expression.hpp:1164
constexpr void inc_ref_count(Expression< Scalar > *expr)
Refcount increment for intrusive shared pointer.
Definition expression.hpp:730
ExpressionPtr< Scalar > cos(const ExpressionPtr< Scalar > &x)
cos() for Expressions.
Definition expression.hpp:1110
ExpressionPtr< Scalar > cbrt(const ExpressionPtr< Scalar > &x)
cbrt() for Expressions.
Definition expression.hpp:554
ExpressionPtr< Scalar > sin(const ExpressionPtr< Scalar > &x)
sin() for Expressions.
Definition expression.hpp:1666
ExpressionPtr< Scalar > hypot(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y)
hypot() for Expressions.
Definition expression.hpp:1347
ExpressionPtr< Scalar > pow(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power)
pow() for Expressions.
Definition expression.hpp:1545
ExpressionPtr< Scalar > sqrt(const ExpressionPtr< Scalar > &x)
sqrt() for Expressions.
Definition expression.hpp:1776
ExpressionPtr< Scalar > sinh(const ExpressionPtr< Scalar > &x)
sinh() for Expressions.
Definition expression.hpp:1721
ExpressionPtr< Scalar > erf(const ExpressionPtr< Scalar > &x)
erf() for Expressions.
Definition expression.hpp:1220
ExpressionPtr< Scalar > tan(const ExpressionPtr< Scalar > &x)
tan() for Expressions.
Definition expression.hpp:1835
ExpressionPtr< Scalar > constant_ptr(Scalar value)
Creates an intrusive shared pointer to a constant expression.
Definition expression.hpp:417
ExpressionPtr< Scalar > abs(const ExpressionPtr< Scalar > &x)
abs() for Expressions.
Definition expression.hpp:822
ExpressionPtr< Scalar > tanh(const ExpressionPtr< Scalar > &x)
tanh() for Expressions.
Definition expression.hpp:1893
constexpr bool USE_POOL_ALLOCATOR
Definition expression.hpp:28
ExpressionPtr< Scalar > exp(const ExpressionPtr< Scalar > &x)
exp() for Expressions.
Definition expression.hpp:1275
ExpressionPtr< Scalar > asin(const ExpressionPtr< Scalar > &x)
asin() for Expressions.
Definition expression.hpp:931
ExpressionPtr< Scalar > log(const ExpressionPtr< Scalar > &x)
log() for Expressions.
Definition expression.hpp:1403
void dec_ref_count(Expression< Scalar > *expr)
Refcount decrement for intrusive shared pointer.
Definition expression.hpp:739
ExpressionPtr< Scalar > acos(const ExpressionPtr< Scalar > &x)
acos() for Expressions.
Definition expression.hpp:877
ExpressionPtr< Scalar > atan(const ExpressionPtr< Scalar > &x)
atan() for Expressions.
Definition expression.hpp:985
IntrusiveSharedPtr< Expression< Scalar > > ExpressionPtr
Typedef for intrusive shared pointer to Expression.
Definition expression.hpp:43
ExpressionPtr< Scalar > atan2(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x)
atan2() for Expressions.
Definition expression.hpp:1052
static ExpressionPtr< typename T::Scalar > make_expression_ptr(Args &&... args)
Creates an intrusive shared pointer to an expression from the global pool allocator.
Definition expression.hpp:51
ExpressionPtr< Scalar > log10(const ExpressionPtr< Scalar > &x)
log10() for Expressions.
Definition expression.hpp:1457
IntrusiveSharedPtr< T > make_intrusive_shared(Args &&... args)
Constructs an object of type T and wraps it in an intrusive shared pointer using args as the paramete...
Definition intrusive_shared_ptr.hpp:260
ExpressionType
Expression type.
Definition expression_type.hpp:16
@ CONSTANT
The expression is a constant.
Definition expression_type.hpp:20
@ QUADRATIC
The expression is composed of quadratic and lower-order operators.
Definition expression_type.hpp:24
@ LINEAR
The expression is composed of linear and lower-order operators.
Definition expression_type.hpp:22
@ NONLINEAR
The expression is composed of nonlinear and lower-order operators.
Definition expression_type.hpp:26
PoolAllocator< T > global_pool_allocator()
Returns an allocator for a global pool memory resource.
Definition pool.hpp:155
IntrusiveSharedPtr< T > allocate_intrusive_shared(Alloc alloc, Args &&... args)
Constructs an object of type T and wraps it in an intrusive shared pointer using alloc as the storage...
Definition intrusive_shared_ptr.hpp:274
Definition StringMap.hpp:773
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:792
constexpr AbsExpression(ExpressionPtr< Scalar > lhs)
Constructs a unary expression (an operator with one argument).
Definition expression.hpp:782
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:790
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:794
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:804
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:785
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:865
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:851
constexpr AcosExpression(ExpressionPtr< Scalar > lhs)
Constructs a unary expression (an operator with one argument).
Definition expression.hpp:848
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:856
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:860
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:858
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:914
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:905
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:912
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:919
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:910
constexpr AsinExpression(ExpressionPtr< Scalar > lhs)
Constructs a unary expression (an operator with one argument).
Definition expression.hpp:902
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:1033
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1023
Scalar value(Scalar y, Scalar x) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1016
constexpr Atan2Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Constructs a binary expression (an operator with two arguments).
Definition expression.hpp:1012
Scalar grad_r(Scalar y, Scalar x, Scalar parent_adjoint) const override
Returns ∂/∂r as a Scalar.
Definition expression.hpp:1029
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂r as an Expression.
Definition expression.hpp:1039
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1021
Scalar grad_l(Scalar y, Scalar x, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:1025
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:967
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:973
constexpr AtanExpression(ExpressionPtr< Scalar > lhs)
Constructs a unary expression (an operator with one argument).
Definition expression.hpp:957
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:960
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:969
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:965
constexpr BinaryMinusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Constructs a binary expression (an operator with two arguments).
Definition expression.hpp:442
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:460
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂r as an Expression.
Definition expression.hpp:466
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:450
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂r as a Scalar.
Definition expression.hpp:456
Scalar value(Scalar lhs, Scalar rhs) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:446
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:448
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:452
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:501
Scalar value(Scalar lhs, Scalar rhs) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:487
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂r as an Expression.
Definition expression.hpp:507
constexpr BinaryPlusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Constructs a binary expression (an operator with two arguments).
Definition expression.hpp:483
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:489
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:491
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂r as a Scalar.
Definition expression.hpp:497
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:493
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:532
constexpr CbrtExpression(ExpressionPtr< Scalar > lhs)
Constructs a unary expression (an operator with one argument).
Definition expression.hpp:522
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:534
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:541
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:530
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:525
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:586
Scalar value(Scalar, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:584
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:588
constexpr ConstantExpression(Scalar value)
Constructs a nullary expression (an operator with no arguments).
Definition expression.hpp:581
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1084
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1089
constexpr CosExpression(ExpressionPtr< Scalar > lhs)
Constructs a unary expression (an operator with one argument).
Definition expression.hpp:1081
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1091
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:1098
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:1093
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1143
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1145
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1138
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:1152
constexpr CoshExpression(ExpressionPtr< Scalar > lhs)
Constructs an unary expression (an operator with one argument).
Definition expression.hpp:1135
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:1147
constexpr DecisionVariableExpression()=default
Constructs a decision variable expression with a value of zero.
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:609
Scalar value(Scalar, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:605
constexpr DecisionVariableExpression(Scalar value)
Constructs a nullary expression (an operator with no arguments).
Definition expression.hpp:602
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:607
constexpr DivExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Constructs a binary expression (an operator with two arguments).
Definition expression.hpp:622
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:627
Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Returns ∂/∂r as a Scalar.
Definition expression.hpp:635
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂r as an Expression.
Definition expression.hpp:645
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:629
Scalar grad_l(Scalar, Scalar rhs, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:631
Scalar value(Scalar lhs, Scalar rhs) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:625
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:639
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1199
constexpr ErfExpression(ExpressionPtr< Scalar > lhs)
Constructs an unary expression (an operator with one argument).
Definition expression.hpp:1189
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:1201
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1192
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:1207
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1197
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:1258
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1256
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1254
constexpr ExpExpression(ExpressionPtr< Scalar > lhs)
Constructs an unary expression (an operator with one argument).
Definition expression.hpp:1246
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:1263
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1249
An autodiff expression node.
Definition expression.hpp:89
Scalar val
The value of the expression node.
Definition expression.hpp:94
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs)
Unary minus operator.
Definition expression.hpp:314
int32_t col
This expression's column in a Jacobian, or -1 otherwise.
Definition expression.hpp:103
virtual Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const
Returns ∂/∂r as a Scalar.
Definition expression.hpp:383
std::array< ExpressionPtr< Scalar >, 2 > args
Expression arguments.
Definition expression.hpp:113
uint32_t ref_count
Reference count for intrusive shared pointer.
Definition expression.hpp:110
constexpr Expression(ExpressionPtr< Scalar > lhs)
Constructs an unary expression (an operator with one argument).
Definition expression.hpp:126
virtual ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const
Returns ∂/∂l as an Expression.
Definition expression.hpp:395
virtual ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const
Returns ∂/∂r as an Expression.
Definition expression.hpp:408
Scalar adjoint
The adjoint of the expression node, used during autodiff.
Definition expression.hpp:97
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Expression-Expression addition operator.
Definition expression.hpp:235
constexpr bool is_constant(Scalar constant) const
Returns true if the expression is the given constant.
Definition expression.hpp:142
constexpr Expression()=default
Constructs a constant expression with a value of zero.
constexpr Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Constructs a binary expression (an operator with two arguments).
Definition expression.hpp:133
ExpressionPtr< Scalar > adjoint_expr
The adjoint of the expression node, used during gradient expression tree generation.
Definition expression.hpp:107
Scalar_ Scalar
Scalar type alias.
Definition expression.hpp:91
virtual ExpressionType type() const =0
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
friend ExpressionPtr< Scalar > operator/(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Expression-Expression division operator.
Definition expression.hpp:200
uint32_t incoming_edges
Counts incoming edges for this node.
Definition expression.hpp:100
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Expression-Expression subtraction operator.
Definition expression.hpp:277
friend ExpressionPtr< Scalar > operator*(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Expression-Expression multiplication operator.
Definition expression.hpp:150
virtual ~Expression()=default
virtual Scalar value(Scalar lhs, Scalar rhs) const =0
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
virtual std::string_view name() const =0
Returns the name of this expression.
virtual Scalar grad_l(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const
Returns ∂/∂l as a Scalar.
Definition expression.hpp:371
friend ExpressionPtr< Scalar > operator+=(ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Expression-Expression compound addition operator.
Definition expression.hpp:268
constexpr Expression(Scalar value)
Constructs a nullary expression (an operator with no arguments).
Definition expression.hpp:121
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs)
Unary plus operator.
Definition expression.hpp:340
constexpr HypotExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Constructs a binary expression (an operator with two arguments).
Definition expression.hpp:1305
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂r as an Expression.
Definition expression.hpp:1334
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1314
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:1328
Scalar grad_l(Scalar x, Scalar y, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:1318
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1316
Scalar value(Scalar x, Scalar y) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1309
Scalar grad_r(Scalar x, Scalar y, Scalar parent_adjoint) const override
Returns ∂/∂r as a Scalar.
Definition expression.hpp:1323
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1432
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1439
constexpr Log10Expression(ExpressionPtr< Scalar > lhs)
Constructs an unary expression (an operator with one argument).
Definition expression.hpp:1429
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1437
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:1445
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:1441
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:1391
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:1387
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1383
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1385
constexpr LogExpression(ExpressionPtr< Scalar > lhs)
Constructs an unary expression (an operator with one argument).
Definition expression.hpp:1375
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1378
Scalar value(Scalar lhs, Scalar rhs) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:665
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:681
Scalar grad_l(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:671
constexpr MultExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Constructs a binary expression (an operator with two arguments).
Definition expression.hpp:662
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:667
Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Returns ∂/∂r as a Scalar.
Definition expression.hpp:676
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:669
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂r as an Expression.
Definition expression.hpp:688
Scalar value(Scalar base, Scalar power) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1492
Scalar grad_l(Scalar base, Scalar power, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:1501
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1497
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:1520
Scalar grad_r(Scalar base, Scalar power, Scalar parent_adjoint) const override
Returns ∂/∂r as a Scalar.
Definition expression.hpp:1507
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂r as an Expression.
Definition expression.hpp:1526
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1499
constexpr PowExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Constructs a binary expression (an operator with two arguments).
Definition expression.hpp:1489
constexpr SignExpression(ExpressionPtr< Scalar > lhs)
Constructs an unary expression (an operator with one argument).
Definition expression.hpp:1588
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1591
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1601
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1603
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:1649
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:1654
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1640
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1647
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1645
constexpr SinExpression(ExpressionPtr< Scalar > lhs)
Constructs an unary expression (an operator with one argument).
Definition expression.hpp:1637
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:1704
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1702
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1700
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:1709
constexpr SinhExpression(ExpressionPtr< Scalar > lhs)
Constructs an unary expression (an operator with one argument).
Definition expression.hpp:1692
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1695
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1755
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1757
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:1764
constexpr SqrtExpression(ExpressionPtr< Scalar > lhs)
Constructs an unary expression (an operator with one argument).
Definition expression.hpp:1747
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:1759
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1750
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:1822
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1813
constexpr TanExpression(ExpressionPtr< Scalar > lhs)
Constructs an unary expression (an operator with one argument).
Definition expression.hpp:1803
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:1815
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1811
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1806
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:1869
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:1871
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:1873
Scalar value(Scalar x, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:1864
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:1880
constexpr TanhExpression(ExpressionPtr< Scalar > lhs)
Constructs an unary expression (an operator with one argument).
Definition expression.hpp:1861
constexpr UnaryMinusExpression(ExpressionPtr< Scalar > lhs)
Constructs an unary expression (an operator with one argument).
Definition expression.hpp:705
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Returns ∂/∂l as an Expression.
Definition expression.hpp:718
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Returns ∂/∂l as a Scalar.
Definition expression.hpp:714
ExpressionType type() const override
Returns the type of this expression (constant, linear, quadratic, or nonlinear).
Definition expression.hpp:710
Scalar value(Scalar lhs, Scalar) const override
Either nullary operator with no arguments, unary operator with one argument, or binary operator with ...
Definition expression.hpp:708
std::string_view name() const override
Returns the name of this expression.
Definition expression.hpp:712