WPILibC++ 2025.1.1
Loading...
Searching...
No Matches
Expression.hpp
Go to the documentation of this file.
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <stdint.h>
6
7#include <algorithm>
8#include <array>
9#include <cmath>
10#include <memory>
11#include <numbers>
12#include <utility>
13
14#include <wpi/SmallVector.h>
15
20
21namespace sleipnir::detail {
22
// Whether expression nodes are carved out of the global pool allocator.
//
// The pool allocator is backed by a thread-local static pool resource. On
// Windows that resource isn't guaranteed to be initialized properly across DLL
// boundaries, so Windows builds opt out and use the default allocation path.
#ifndef _WIN32
inline constexpr bool kUsePoolAllocator = true;
#else
inline constexpr bool kUsePoolAllocator = false;
#endif
30
32
35
36/**
37 * Typedef for intrusive shared pointer to Expression.
38 */
40
41/**
42 * Creates an intrusive shared pointer to an expression from the global pool
43 * allocator.
44 *
45 * @param args Constructor arguments for Expression.
46 */
47template <typename... Args>
48static ExpressionPtr MakeExpressionPtr(Args&&... args) {
49 if constexpr (kUsePoolAllocator) {
51 GlobalPoolAllocator<Expression>(), std::forward<Args>(args)...);
52 } else {
53 return MakeIntrusiveShared<Expression>(std::forward<Args>(args)...);
54 }
55}
56
57/**
58 * An autodiff expression node.
59 */
61 /**
62 * Binary function taking two doubles and returning a double.
63 */
64 using BinaryFuncDouble = double (*)(double, double);
65
66 /**
67 * Trinary function taking three doubles and returning a double.
68 */
69 using TrinaryFuncDouble = double (*)(double, double, double);
70
71 /**
72 * Trinary function taking three expressions and returning an expression.
73 */
75 const ExpressionPtr&,
76 const ExpressionPtr&);
77
78 /// The value of the expression node.
79 double value = 0.0;
80
81 /// The adjoint of the expression node used during autodiff.
82 double adjoint = 0.0;
83
84 /// Tracks the number of instances of this expression yet to be encountered in
85 /// an expression tree.
86 uint32_t duplications = 0;
87
88 /// This expression's row in wrt for autodiff gradient, Jacobian, or Hessian.
89 /// This is -1 if the expression isn't in wrt.
90 int32_t row = -1;
91
92 /// The adjoint of the expression node used during gradient expression tree
93 /// generation.
95
96 /// Expression argument type.
97 ExpressionType type = ExpressionType::kConstant;
98
99 /// Reference count for intrusive shared pointer.
100 uint32_t refCount = 0;
101
102 /// Either nullary operator with no arguments, unary operator with one
103 /// argument, or binary operator with two arguments. This operator is
104 /// used to update the node's value.
105 BinaryFuncDouble valueFunc = nullptr;
106
107 /// Functions returning double adjoints of the children expressions.
108 ///
109 /// Parameters:
110 /// <ul>
111 /// <li>lhs: Left argument to binary operator.</li>
112 /// <li>rhs: Right argument to binary operator.</li>
113 /// <li>parentAdjoint: Adjoint of parent expression.</li>
114 /// </ul>
115 std::array<TrinaryFuncDouble, 2> gradientValueFuncs{nullptr, nullptr};
116
117 /// Functions returning Variable adjoints of the children expressions.
118 ///
119 /// Parameters:
120 /// <ul>
121 /// <li>lhs: Left argument to binary operator.</li>
122 /// <li>rhs: Right argument to binary operator.</li>
123 /// <li>parentAdjoint: Adjoint of parent expression.</li>
124 /// </ul>
125 std::array<TrinaryFuncExpr, 2> gradientFuncs{nullptr, nullptr};
126
127 /// Expression arguments.
128 std::array<ExpressionPtr, 2> args{nullptr, nullptr};
129
130 /**
131 * Constructs a constant expression with a value of zero.
132 */
133 constexpr Expression() = default;
134
135 /**
136 * Constructs a nullary expression (an operator with no arguments).
137 *
138 * @param value The expression value.
139 * @param type The expression type. It should be either constant (the default)
140 * or linear.
141 */
142 explicit constexpr Expression(double value,
143 ExpressionType type = ExpressionType::kConstant)
144 : value{value}, type{type} {}
145
146 /**
147 * Constructs an unary expression (an operator with one argument).
148 *
149 * @param type The expression's type.
150 * @param valueFunc Unary operator that produces this expression's value.
151 * @param lhsGradientValueFunc Gradient with respect to the operand.
152 * @param lhsGradientFunc Gradient with respect to the operand.
153 * @param lhs Unary operator's operand.
154 */
155 constexpr Expression(ExpressionType type, BinaryFuncDouble valueFunc,
156 TrinaryFuncDouble lhsGradientValueFunc,
157 TrinaryFuncExpr lhsGradientFunc, ExpressionPtr lhs)
158 : value{valueFunc(lhs->value, 0.0)},
159 type{type},
160 valueFunc{valueFunc},
161 gradientValueFuncs{lhsGradientValueFunc, nullptr},
162 gradientFuncs{lhsGradientFunc, nullptr},
163 args{lhs, nullptr} {}
164
165 /**
166 * Constructs a binary expression (an operator with two arguments).
167 *
168 * @param type The expression's type.
169 * @param valueFunc Unary operator that produces this expression's value.
170 * @param lhsGradientValueFunc Gradient with respect to the left operand.
171 * @param rhsGradientValueFunc Gradient with respect to the right operand.
172 * @param lhsGradientFunc Gradient with respect to the left operand.
173 * @param rhsGradientFunc Gradient with respect to the right operand.
174 * @param lhs Binary operator's left operand.
175 * @param rhs Binary operator's right operand.
176 */
177 constexpr Expression(ExpressionType type, BinaryFuncDouble valueFunc,
178 TrinaryFuncDouble lhsGradientValueFunc,
179 TrinaryFuncDouble rhsGradientValueFunc,
180 TrinaryFuncExpr lhsGradientFunc,
181 TrinaryFuncExpr rhsGradientFunc, ExpressionPtr lhs,
182 ExpressionPtr rhs)
183 : value{valueFunc(lhs->value, rhs->value)},
184 type{type},
185 valueFunc{valueFunc},
186 gradientValueFuncs{lhsGradientValueFunc, rhsGradientValueFunc},
187 gradientFuncs{lhsGradientFunc, rhsGradientFunc},
188 args{lhs, rhs} {}
189
190 /**
191 * Returns true if the expression is the given constant.
192 *
193 * @param constant The constant.
194 */
195 constexpr bool IsConstant(double constant) const {
196 return type == ExpressionType::kConstant && value == constant;
197 }
198
199 /**
200 * Expression-Expression multiplication operator.
201 *
202 * @param lhs Operator left-hand side.
203 * @param rhs Operator right-hand side.
204 */
206 const ExpressionPtr& rhs) {
207 using enum ExpressionType;
208
209 // Prune expression
210 if (lhs->IsConstant(0.0)) {
211 // Return zero
212 return lhs;
213 } else if (rhs->IsConstant(0.0)) {
214 // Return zero
215 return rhs;
216 } else if (lhs->IsConstant(1.0)) {
217 return rhs;
218 } else if (rhs->IsConstant(1.0)) {
219 return lhs;
220 }
221
222 // Evaluate constant
223 if (lhs->type == kConstant && rhs->type == kConstant) {
224 return MakeExpressionPtr(lhs->value * rhs->value);
225 }
226
227 // Evaluate expression type
228 ExpressionType type;
229 if (lhs->type == kConstant) {
230 type = rhs->type;
231 } else if (rhs->type == kConstant) {
232 type = lhs->type;
233 } else if (lhs->type == kLinear && rhs->type == kLinear) {
234 type = kQuadratic;
235 } else {
236 type = kNonlinear;
237 }
238
239 return MakeExpressionPtr(
240 type, [](double lhs, double rhs) { return lhs * rhs; },
241 [](double, double rhs, double parentAdjoint) {
242 return parentAdjoint * rhs;
243 },
244 [](double lhs, double, double parentAdjoint) {
245 return parentAdjoint * lhs;
246 },
247 [](const ExpressionPtr&, const ExpressionPtr& rhs,
248 const ExpressionPtr& parentAdjoint) { return parentAdjoint * rhs; },
249 [](const ExpressionPtr& lhs, const ExpressionPtr&,
250 const ExpressionPtr& parentAdjoint) { return parentAdjoint * lhs; },
251 lhs, rhs);
252 }
253
254 /**
255 * Expression-Expression division operator.
256 *
257 * @param lhs Operator left-hand side.
258 * @param rhs Operator right-hand side.
259 */
261 const ExpressionPtr& rhs) {
262 using enum ExpressionType;
263
264 // Prune expression
265 if (lhs->IsConstant(0.0)) {
266 // Return zero
267 return lhs;
268 } else if (rhs->IsConstant(1.0)) {
269 return lhs;
270 }
271
272 // Evaluate constant
273 if (lhs->type == kConstant && rhs->type == kConstant) {
274 return MakeExpressionPtr(lhs->value / rhs->value);
275 }
276
277 // Evaluate expression type
278 ExpressionType type;
279 if (rhs->type == kConstant) {
280 type = lhs->type;
281 } else {
282 type = kNonlinear;
283 }
284
285 return MakeExpressionPtr(
286 type, [](double lhs, double rhs) { return lhs / rhs; },
287 [](double, double rhs, double parentAdjoint) {
288 return parentAdjoint / rhs;
289 },
290 [](double lhs, double rhs, double parentAdjoint) {
291 return parentAdjoint * -lhs / (rhs * rhs);
292 },
293 [](const ExpressionPtr&, const ExpressionPtr& rhs,
294 const ExpressionPtr& parentAdjoint) { return parentAdjoint / rhs; },
295 [](const ExpressionPtr& lhs, const ExpressionPtr& rhs,
296 const ExpressionPtr& parentAdjoint) {
297 return parentAdjoint * -lhs / (rhs * rhs);
298 },
299 lhs, rhs);
300 }
301
302 /**
303 * Expression-Expression addition operator.
304 *
305 * @param lhs Operator left-hand side.
306 * @param rhs Operator right-hand side.
307 */
309 const ExpressionPtr& rhs) {
310 using enum ExpressionType;
311
312 // Prune expression
313 if (lhs == nullptr || lhs->IsConstant(0.0)) {
314 return rhs;
315 } else if (rhs == nullptr || rhs->IsConstant(0.0)) {
316 return lhs;
317 }
318
319 // Evaluate constant
320 if (lhs->type == kConstant && rhs->type == kConstant) {
321 return MakeExpressionPtr(lhs->value + rhs->value);
322 }
323
324 return MakeExpressionPtr(
325 std::max(lhs->type, rhs->type),
326 [](double lhs, double rhs) { return lhs + rhs; },
327 [](double, double, double parentAdjoint) { return parentAdjoint; },
328 [](double, double, double parentAdjoint) { return parentAdjoint; },
329 [](const ExpressionPtr&, const ExpressionPtr&,
330 const ExpressionPtr& parentAdjoint) { return parentAdjoint; },
331 [](const ExpressionPtr&, const ExpressionPtr&,
332 const ExpressionPtr& parentAdjoint) { return parentAdjoint; },
333 lhs, rhs);
334 }
335
336 /**
337 * Expression-Expression subtraction operator.
338 *
339 * @param lhs Operator left-hand side.
340 * @param rhs Operator right-hand side.
341 */
343 const ExpressionPtr& rhs) {
344 using enum ExpressionType;
345
346 // Prune expression
347 if (lhs->IsConstant(0.0)) {
348 if (rhs->IsConstant(0.0)) {
349 // Return zero
350 return rhs;
351 } else {
352 return -rhs;
353 }
354 } else if (rhs->IsConstant(0.0)) {
355 return lhs;
356 }
357
358 // Evaluate constant
359 if (lhs->type == kConstant && rhs->type == kConstant) {
360 return MakeExpressionPtr(lhs->value - rhs->value);
361 }
362
363 return MakeExpressionPtr(
364 std::max(lhs->type, rhs->type),
365 [](double lhs, double rhs) { return lhs - rhs; },
366 [](double, double, double parentAdjoint) { return parentAdjoint; },
367 [](double, double, double parentAdjoint) { return -parentAdjoint; },
368 [](const ExpressionPtr&, const ExpressionPtr&,
369 const ExpressionPtr& parentAdjoint) { return parentAdjoint; },
370 [](const ExpressionPtr&, const ExpressionPtr&,
371 const ExpressionPtr& parentAdjoint) { return -parentAdjoint; },
372 lhs, rhs);
373 }
374
375 /**
376 * Unary minus operator.
377 *
378 * @param lhs Operand of unary minus.
379 */
381 using enum ExpressionType;
382
383 // Prune expression
384 if (lhs->IsConstant(0.0)) {
385 // Return zero
386 return lhs;
387 }
388
389 // Evaluate constant
390 if (lhs->type == kConstant) {
391 return MakeExpressionPtr(-lhs->value);
392 }
393
394 return MakeExpressionPtr(
395 lhs->type, [](double lhs, double) { return -lhs; },
396 [](double, double, double parentAdjoint) { return -parentAdjoint; },
397 [](const ExpressionPtr&, const ExpressionPtr&,
398 const ExpressionPtr& parentAdjoint) { return -parentAdjoint; },
399 lhs);
400 }
401
402 /**
403 * Unary plus operator.
404 *
405 * @param lhs Operand of unary plus.
406 */
408 return lhs;
409 }
410};
411
416
417/**
418 * Refcount increment for intrusive shared pointer.
419 *
420 * @param expr The shared pointer's managed object.
421 */
423 ++expr->refCount;
424}
425
426/**
427 * Refcount decrement for intrusive shared pointer.
428 *
429 * @param expr The shared pointer's managed object.
430 */
432 // If a deeply nested tree is being deallocated all at once, calling the
433 // Expression destructor when expr's refcount reaches zero can cause a stack
434 // overflow. Instead, we iterate over its children to decrement their
435 // refcounts and deallocate them.
437 stack.emplace_back(expr);
438
439 while (!stack.empty()) {
440 auto elem = stack.back();
441 stack.pop_back();
442
443 // Decrement the current node's refcount. If it reaches zero, deallocate the
444 // node and enqueue its children so their refcounts are decremented too.
445 if (--elem->refCount == 0) {
446 if (elem->adjointExpr != nullptr) {
447 stack.emplace_back(elem->adjointExpr.Get());
448 }
449 for (auto&& arg : elem->args) {
450 if (arg != nullptr) {
451 stack.emplace_back(arg.Get());
452 }
453 }
454
455 // Not calling the destructor here is safe because it only decrements
456 // refcounts, which was already done above.
457 if constexpr (kUsePoolAllocator) {
458 auto alloc = GlobalPoolAllocator<Expression>();
459 std::allocator_traits<decltype(alloc)>::deallocate(alloc, elem,
460 sizeof(Expression));
461 }
462 }
463 }
464}
465
466/**
467 * std::abs() for Expressions.
468 *
469 * @param x The argument.
470 */
472 const ExpressionPtr& x) {
473 using enum ExpressionType;
474
475 // Prune expression
476 if (x->IsConstant(0.0)) {
477 // Return zero
478 return x;
479 }
480
481 // Evaluate constant
482 if (x->type == kConstant) {
483 return MakeExpressionPtr(std::abs(x->value));
484 }
485
486 return MakeExpressionPtr(
487 kNonlinear, [](double x, double) { return std::abs(x); },
488 [](double x, double, double parentAdjoint) {
489 if (x < 0.0) {
490 return -parentAdjoint;
491 } else if (x > 0.0) {
492 return parentAdjoint;
493 } else {
494 return 0.0;
495 }
496 },
497 [](const ExpressionPtr& x, const ExpressionPtr&,
498 const ExpressionPtr& parentAdjoint) {
499 if (x->value < 0.0) {
500 return -parentAdjoint;
501 } else if (x->value > 0.0) {
502 return parentAdjoint;
503 } else {
504 // Return zero
505 return MakeExpressionPtr();
506 }
507 },
508 x);
509}
510
511/**
512 * std::acos() for Expressions.
513 *
514 * @param x The argument.
515 */
517 const ExpressionPtr& x) {
518 using enum ExpressionType;
519
520 // Prune expression
521 if (x->IsConstant(0.0)) {
522 return MakeExpressionPtr(std::numbers::pi / 2.0);
523 }
524
525 // Evaluate constant
526 if (x->type == kConstant) {
527 return MakeExpressionPtr(std::acos(x->value));
528 }
529
530 return MakeExpressionPtr(
531 kNonlinear, [](double x, double) { return std::acos(x); },
532 [](double x, double, double parentAdjoint) {
533 return -parentAdjoint / std::sqrt(1.0 - x * x);
534 },
535 [](const ExpressionPtr& x, const ExpressionPtr&,
536 const ExpressionPtr& parentAdjoint) {
537 return -parentAdjoint /
539 },
540 x);
541}
542
543/**
544 * std::asin() for Expressions.
545 *
546 * @param x The argument.
547 */
549 const ExpressionPtr& x) {
550 using enum ExpressionType;
551
552 // Prune expression
553 if (x->IsConstant(0.0)) {
554 // Return zero
555 return x;
556 }
557
558 // Evaluate constant
559 if (x->type == kConstant) {
560 return MakeExpressionPtr(std::asin(x->value));
561 }
562
563 return MakeExpressionPtr(
564 kNonlinear, [](double x, double) { return std::asin(x); },
565 [](double x, double, double parentAdjoint) {
566 return parentAdjoint / std::sqrt(1.0 - x * x);
567 },
568 [](const ExpressionPtr& x, const ExpressionPtr&,
569 const ExpressionPtr& parentAdjoint) {
570 return parentAdjoint /
572 },
573 x);
574}
575
576/**
577 * std::atan() for Expressions.
578 *
579 * @param x The argument.
580 */
582 const ExpressionPtr& x) {
583 using enum ExpressionType;
584
585 // Prune expression
586 if (x->IsConstant(0.0)) {
587 // Return zero
588 return x;
589 }
590
591 // Evaluate constant
592 if (x->type == kConstant) {
593 return MakeExpressionPtr(std::atan(x->value));
594 }
595
596 return MakeExpressionPtr(
597 kNonlinear, [](double x, double) { return std::atan(x); },
598 [](double x, double, double parentAdjoint) {
599 return parentAdjoint / (1.0 + x * x);
600 },
601 [](const ExpressionPtr& x, const ExpressionPtr&,
602 const ExpressionPtr& parentAdjoint) {
603 return parentAdjoint / (MakeExpressionPtr(1.0) + x * x);
604 },
605 x);
606}
607
608/**
609 * std::atan2() for Expressions.
610 *
611 * @param y The y argument.
612 * @param x The x argument.
613 */
615 const ExpressionPtr& y, const ExpressionPtr& x) {
616 using enum ExpressionType;
617
618 // Prune expression
619 if (y->IsConstant(0.0)) {
620 // Return zero
621 return y;
622 } else if (x->IsConstant(0.0)) {
623 return MakeExpressionPtr(std::numbers::pi / 2.0);
624 }
625
626 // Evaluate constant
627 if (y->type == kConstant && x->type == kConstant) {
628 return MakeExpressionPtr(std::atan2(y->value, x->value));
629 }
630
631 return MakeExpressionPtr(
632 kNonlinear, [](double y, double x) { return std::atan2(y, x); },
633 [](double y, double x, double parentAdjoint) {
634 return parentAdjoint * x / (y * y + x * x);
635 },
636 [](double y, double x, double parentAdjoint) {
637 return parentAdjoint * -y / (y * y + x * x);
638 },
639 [](const ExpressionPtr& y, const ExpressionPtr& x,
640 const ExpressionPtr& parentAdjoint) {
641 return parentAdjoint * x / (y * y + x * x);
642 },
643 [](const ExpressionPtr& y, const ExpressionPtr& x,
644 const ExpressionPtr& parentAdjoint) {
645 return parentAdjoint * -y / (y * y + x * x);
646 },
647 y, x);
648}
649
650/**
651 * std::cos() for Expressions.
652 *
653 * @param x The argument.
654 */
656 const ExpressionPtr& x) {
657 using enum ExpressionType;
658
659 // Prune expression
660 if (x->IsConstant(0.0)) {
661 return MakeExpressionPtr(1.0);
662 }
663
664 // Evaluate constant
665 if (x->type == kConstant) {
666 return MakeExpressionPtr(std::cos(x->value));
667 }
668
669 return MakeExpressionPtr(
670 kNonlinear, [](double x, double) { return std::cos(x); },
671 [](double x, double, double parentAdjoint) {
672 return -parentAdjoint * std::sin(x);
673 },
674 [](const ExpressionPtr& x, const ExpressionPtr&,
675 const ExpressionPtr& parentAdjoint) {
676 return parentAdjoint * -sleipnir::detail::sin(x);
677 },
678 x);
679}
680
681/**
682 * std::cosh() for Expressions.
683 *
684 * @param x The argument.
685 */
687 const ExpressionPtr& x) {
688 using enum ExpressionType;
689
690 // Prune expression
691 if (x->IsConstant(0.0)) {
692 return MakeExpressionPtr(1.0);
693 }
694
695 // Evaluate constant
696 if (x->type == kConstant) {
697 return MakeExpressionPtr(std::cosh(x->value));
698 }
699
700 return MakeExpressionPtr(
701 kNonlinear, [](double x, double) { return std::cosh(x); },
702 [](double x, double, double parentAdjoint) {
703 return parentAdjoint * std::sinh(x);
704 },
705 [](const ExpressionPtr& x, const ExpressionPtr&,
706 const ExpressionPtr& parentAdjoint) {
707 return parentAdjoint * sleipnir::detail::sinh(x);
708 },
709 x);
710}
711
712/**
713 * std::erf() for Expressions.
714 *
715 * @param x The argument.
716 */
718 const ExpressionPtr& x) {
719 using enum ExpressionType;
720
721 // Prune expression
722 if (x->IsConstant(0.0)) {
723 // Return zero
724 return x;
725 }
726
727 // Evaluate constant
728 if (x->type == kConstant) {
729 return MakeExpressionPtr(std::erf(x->value));
730 }
731
732 return MakeExpressionPtr(
733 kNonlinear, [](double x, double) { return std::erf(x); },
734 [](double x, double, double parentAdjoint) {
735 return parentAdjoint * 2.0 * std::numbers::inv_sqrtpi *
736 std::exp(-x * x);
737 },
738 [](const ExpressionPtr& x, const ExpressionPtr&,
739 const ExpressionPtr& parentAdjoint) {
740 return parentAdjoint *
741 MakeExpressionPtr(2.0 * std::numbers::inv_sqrtpi) *
742 sleipnir::detail::exp(-x * x);
743 },
744 x);
745}
746
747/**
748 * std::exp() for Expressions.
749 *
750 * @param x The argument.
751 */
753 const ExpressionPtr& x) {
754 using enum ExpressionType;
755
756 // Prune expression
757 if (x->IsConstant(0.0)) {
758 return MakeExpressionPtr(1.0);
759 }
760
761 // Evaluate constant
762 if (x->type == kConstant) {
763 return MakeExpressionPtr(std::exp(x->value));
764 }
765
766 return MakeExpressionPtr(
767 kNonlinear, [](double x, double) { return std::exp(x); },
768 [](double x, double, double parentAdjoint) {
769 return parentAdjoint * std::exp(x);
770 },
771 [](const ExpressionPtr& x, const ExpressionPtr&,
772 const ExpressionPtr& parentAdjoint) {
773 return parentAdjoint * sleipnir::detail::exp(x);
774 },
775 x);
776}
777
778/**
779 * std::hypot() for Expressions.
780 *
781 * @param x The x argument.
782 * @param y The y argument.
783 */
785 const ExpressionPtr& x, const ExpressionPtr& y) {
786 using enum ExpressionType;
787
788 // Prune expression
789 if (x->IsConstant(0.0)) {
790 return y;
791 } else if (y->IsConstant(0.0)) {
792 return x;
793 }
794
795 // Evaluate constant
796 if (x->type == kConstant && y->type == kConstant) {
797 return MakeExpressionPtr(std::hypot(x->value, y->value));
798 }
799
800 return MakeExpressionPtr(
801 kNonlinear, [](double x, double y) { return std::hypot(x, y); },
802 [](double x, double y, double parentAdjoint) {
803 return parentAdjoint * x / std::hypot(x, y);
804 },
805 [](double x, double y, double parentAdjoint) {
806 return parentAdjoint * y / std::hypot(x, y);
807 },
808 [](const ExpressionPtr& x, const ExpressionPtr& y,
809 const ExpressionPtr& parentAdjoint) {
810 return parentAdjoint * x / sleipnir::detail::hypot(x, y);
811 },
812 [](const ExpressionPtr& x, const ExpressionPtr& y,
813 const ExpressionPtr& parentAdjoint) {
814 return parentAdjoint * y / sleipnir::detail::hypot(x, y);
815 },
816 x, y);
817}
818
819/**
820 * std::log() for Expressions.
821 *
822 * @param x The argument.
823 */
825 const ExpressionPtr& x) {
826 using enum ExpressionType;
827
828 // Prune expression
829 if (x->IsConstant(0.0)) {
830 // Return zero
831 return x;
832 }
833
834 // Evaluate constant
835 if (x->type == kConstant) {
836 return MakeExpressionPtr(std::log(x->value));
837 }
838
839 return MakeExpressionPtr(
840 kNonlinear, [](double x, double) { return std::log(x); },
841 [](double x, double, double parentAdjoint) { return parentAdjoint / x; },
842 [](const ExpressionPtr& x, const ExpressionPtr&,
843 const ExpressionPtr& parentAdjoint) { return parentAdjoint / x; },
844 x);
845}
846
847/**
848 * std::log10() for Expressions.
849 *
850 * @param x The argument.
851 */
853 const ExpressionPtr& x) {
854 using enum ExpressionType;
855
856 // Prune expression
857 if (x->IsConstant(0.0)) {
858 // Return zero
859 return x;
860 }
861
862 // Evaluate constant
863 if (x->type == kConstant) {
864 return MakeExpressionPtr(std::log10(x->value));
865 }
866
867 return MakeExpressionPtr(
868 kNonlinear, [](double x, double) { return std::log10(x); },
869 [](double x, double, double parentAdjoint) {
870 return parentAdjoint / (std::numbers::ln10 * x);
871 },
872 [](const ExpressionPtr& x, const ExpressionPtr&,
873 const ExpressionPtr& parentAdjoint) {
874 return parentAdjoint / (MakeExpressionPtr(std::numbers::ln10) * x);
875 },
876 x);
877}
878
879/**
880 * std::pow() for Expressions.
881 *
882 * @param base The base.
883 * @param power The power.
884 */
886 const ExpressionPtr& base, const ExpressionPtr& power) {
887 using enum ExpressionType;
888
889 // Prune expression
890 if (base->IsConstant(0.0)) {
891 // Return zero
892 return base;
893 } else if (base->IsConstant(1.0)) {
894 return base;
895 }
896 if (power->IsConstant(0.0)) {
897 return MakeExpressionPtr(1.0);
898 } else if (power->IsConstant(1.0)) {
899 return base;
900 }
901
902 // Evaluate constant
903 if (base->type == kConstant && power->type == kConstant) {
904 return MakeExpressionPtr(std::pow(base->value, power->value));
905 }
906
907 return MakeExpressionPtr(
908 base->type == kLinear && power->IsConstant(2.0) ? kQuadratic : kNonlinear,
909 [](double base, double power) { return std::pow(base, power); },
910 [](double base, double power, double parentAdjoint) {
911 return parentAdjoint * std::pow(base, power - 1) * power;
912 },
913 [](double base, double power, double parentAdjoint) {
914 // Since x * std::log(x) -> 0 as x -> 0
915 if (base == 0.0) {
916 return 0.0;
917 } else {
918 return parentAdjoint * std::pow(base, power - 1) * base *
919 std::log(base);
920 }
921 },
922 [](const ExpressionPtr& base, const ExpressionPtr& power,
923 const ExpressionPtr& parentAdjoint) {
924 return parentAdjoint *
925 sleipnir::detail::pow(base, power - MakeExpressionPtr(1.0)) *
926 power;
927 },
928 [](const ExpressionPtr& base, const ExpressionPtr& power,
929 const ExpressionPtr& parentAdjoint) {
930 // Since x * std::log(x) -> 0 as x -> 0
931 if (base->value == 0.0) {
932 // Return zero
933 return base;
934 } else {
935 return parentAdjoint *
936 sleipnir::detail::pow(base, power - MakeExpressionPtr(1.0)) *
937 base * sleipnir::detail::log(base);
938 }
939 },
940 base, power);
941}
942
943/**
944 * sign() for Expressions.
945 *
946 * @param x The argument.
947 */
949 using enum ExpressionType;
950
951 // Evaluate constant
952 if (x->type == kConstant) {
953 if (x->value < 0.0) {
954 return MakeExpressionPtr(-1.0);
955 } else if (x->value == 0.0) {
956 // Return zero
957 return x;
958 } else {
959 return MakeExpressionPtr(1.0);
960 }
961 }
962
963 return MakeExpressionPtr(
965 [](double x, double) {
966 if (x < 0.0) {
967 return -1.0;
968 } else if (x == 0.0) {
969 return 0.0;
970 } else {
971 return 1.0;
972 }
973 },
974 [](double, double, double) { return 0.0; },
975 [](const ExpressionPtr&, const ExpressionPtr&, const ExpressionPtr&) {
976 // Return zero
977 return MakeExpressionPtr();
978 },
979 x);
980}
981
982/**
983 * std::sin() for Expressions.
984 *
985 * @param x The argument.
986 */
988 const ExpressionPtr& x) {
989 using enum ExpressionType;
990
991 // Prune expression
992 if (x->IsConstant(0.0)) {
993 // Return zero
994 return x;
995 }
996
997 // Evaluate constant
998 if (x->type == kConstant) {
999 return MakeExpressionPtr(std::sin(x->value));
1000 }
1001
1002 return MakeExpressionPtr(
1003 kNonlinear, [](double x, double) { return std::sin(x); },
1004 [](double x, double, double parentAdjoint) {
1005 return parentAdjoint * std::cos(x);
1006 },
1007 [](const ExpressionPtr& x, const ExpressionPtr&,
1008 const ExpressionPtr& parentAdjoint) {
1009 return parentAdjoint * sleipnir::detail::cos(x);
1010 },
1011 x);
1012}
1013
1014/**
1015 * std::sinh() for Expressions.
1016 *
1017 * @param x The argument.
1018 */
1020 using enum ExpressionType;
1021
1022 // Prune expression
1023 if (x->IsConstant(0.0)) {
1024 // Return zero
1025 return x;
1026 }
1027
1028 // Evaluate constant
1029 if (x->type == kConstant) {
1030 return MakeExpressionPtr(std::sinh(x->value));
1031 }
1032
1033 return MakeExpressionPtr(
1034 kNonlinear, [](double x, double) { return std::sinh(x); },
1035 [](double x, double, double parentAdjoint) {
1036 return parentAdjoint * std::cosh(x);
1037 },
1038 [](const ExpressionPtr& x, const ExpressionPtr&,
1039 const ExpressionPtr& parentAdjoint) {
1040 return parentAdjoint * sleipnir::detail::cosh(x);
1041 },
1042 x);
1043}
1044
1045/**
1046 * std::sqrt() for Expressions.
1047 *
1048 * @param x The argument.
1049 */
1051 const ExpressionPtr& x) {
1052 using enum ExpressionType;
1053
1054 // Evaluate constant
1055 if (x->type == kConstant) {
1056 if (x->value == 0.0) {
1057 // Return zero
1058 return x;
1059 } else if (x->value == 1.0) {
1060 return x;
1061 } else {
1062 return MakeExpressionPtr(std::sqrt(x->value));
1063 }
1064 }
1065
1066 return MakeExpressionPtr(
1067 kNonlinear, [](double x, double) { return std::sqrt(x); },
1068 [](double x, double, double parentAdjoint) {
1069 return parentAdjoint / (2.0 * std::sqrt(x));
1070 },
1071 [](const ExpressionPtr& x, const ExpressionPtr&,
1072 const ExpressionPtr& parentAdjoint) {
1073 return parentAdjoint /
1075 },
1076 x);
1077}
1078
1079/**
1080 * std::tan() for Expressions.
1081 *
1082 * @param x The argument.
1083 */
1085 const ExpressionPtr& x) {
1086 using enum ExpressionType;
1087
1088 // Prune expression
1089 if (x->IsConstant(0.0)) {
1090 // Return zero
1091 return x;
1092 }
1093
1094 // Evaluate constant
1095 if (x->type == kConstant) {
1096 return MakeExpressionPtr(std::tan(x->value));
1097 }
1098
1099 return MakeExpressionPtr(
1100 kNonlinear, [](double x, double) { return std::tan(x); },
1101 [](double x, double, double parentAdjoint) {
1102 return parentAdjoint / (std::cos(x) * std::cos(x));
1103 },
1104 [](const ExpressionPtr& x, const ExpressionPtr&,
1105 const ExpressionPtr& parentAdjoint) {
1106 return parentAdjoint /
1108 },
1109 x);
1110}
1111
1112/**
1113 * std::tanh() for Expressions.
1114 *
1115 * @param x The argument.
1116 */
1118 using enum ExpressionType;
1119
1120 // Prune expression
1121 if (x->IsConstant(0.0)) {
1122 // Return zero
1123 return x;
1124 }
1125
1126 // Evaluate constant
1127 if (x->type == kConstant) {
1128 return MakeExpressionPtr(std::tanh(x->value));
1129 }
1130
1131 return MakeExpressionPtr(
1132 kNonlinear, [](double x, double) { return std::tanh(x); },
1133 [](double x, double, double parentAdjoint) {
1134 return parentAdjoint / (std::cosh(x) * std::cosh(x));
1135 },
1136 [](const ExpressionPtr& x, const ExpressionPtr&,
1137 const ExpressionPtr& parentAdjoint) {
1138 return parentAdjoint /
1140 },
1141 x);
1142}
1143
1144} // namespace sleipnir::detail
This file defines the SmallVector class.
#define SLEIPNIR_DLLEXPORT
Definition SymbolExports.hpp:34
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition SmallVector.h:1212
reference emplace_back(ArgTypes &&... Args)
Definition SmallVector.h:953
void pop_back()
Definition SmallVector.h:441
bool empty() const
Definition SmallVector.h:102
reference back()
Definition SmallVector.h:324
Definition ExpressionGraph.hpp:13
SLEIPNIR_DLLEXPORT ExpressionPtr sin(const ExpressionPtr &x)
std::sin() for Expressions.
Definition Expression.hpp:987
SLEIPNIR_DLLEXPORT ExpressionPtr abs(const ExpressionPtr &x)
std::abs() for Expressions.
Definition Expression.hpp:471
SLEIPNIR_DLLEXPORT ExpressionPtr log10(const ExpressionPtr &x)
std::log10() for Expressions.
Definition Expression.hpp:852
SLEIPNIR_DLLEXPORT ExpressionPtr asin(const ExpressionPtr &x)
std::asin() for Expressions.
Definition Expression.hpp:548
SLEIPNIR_DLLEXPORT ExpressionPtr exp(const ExpressionPtr &x)
std::exp() for Expressions.
Definition Expression.hpp:752
static ExpressionPtr MakeExpressionPtr(Args &&... args)
Creates an intrusive shared pointer to an expression from the global pool allocator.
Definition Expression.hpp:48
SLEIPNIR_DLLEXPORT ExpressionPtr sinh(const ExpressionPtr &x)
std::sinh() for Expressions.
Definition Expression.hpp:1019
SLEIPNIR_DLLEXPORT ExpressionPtr log(const ExpressionPtr &x)
std::log() for Expressions.
Definition Expression.hpp:824
SLEIPNIR_DLLEXPORT ExpressionPtr hypot(const ExpressionPtr &x, const ExpressionPtr &y)
std::hypot() for Expressions.
Definition Expression.hpp:784
SLEIPNIR_DLLEXPORT ExpressionPtr cosh(const ExpressionPtr &x)
std::cosh() for Expressions.
Definition Expression.hpp:686
SLEIPNIR_DLLEXPORT ExpressionPtr acos(const ExpressionPtr &x)
std::acos() for Expressions.
Definition Expression.hpp:516
SLEIPNIR_DLLEXPORT ExpressionPtr pow(const ExpressionPtr &base, const ExpressionPtr &power)
std::pow() for Expressions.
Definition Expression.hpp:885
SLEIPNIR_DLLEXPORT ExpressionPtr atan2(const ExpressionPtr &y, const ExpressionPtr &x)
std::atan2() for Expressions.
Definition Expression.hpp:614
constexpr bool kUsePoolAllocator
Definition Expression.hpp:28
IntrusiveSharedPtr< Expression > ExpressionPtr
Typedef for intrusive shared pointer to Expression.
Definition Expression.hpp:39
SLEIPNIR_DLLEXPORT ExpressionPtr atan(const ExpressionPtr &x)
std::atan() for Expressions.
Definition Expression.hpp:581
SLEIPNIR_DLLEXPORT ExpressionPtr cos(const ExpressionPtr &x)
std::cos() for Expressions.
Definition Expression.hpp:655
void IntrusiveSharedPtrDecRefCount(Expression *expr)
Refcount decrement for intrusive shared pointer.
Definition Expression.hpp:431
void IntrusiveSharedPtrIncRefCount(Expression *expr)
Refcount increment for intrusive shared pointer.
Definition Expression.hpp:422
SLEIPNIR_DLLEXPORT ExpressionPtr sqrt(const ExpressionPtr &x)
std::sqrt() for Expressions.
Definition Expression.hpp:1050
SLEIPNIR_DLLEXPORT ExpressionPtr erf(const ExpressionPtr &x)
std::erf() for Expressions.
Definition Expression.hpp:717
IntrusiveSharedPtr< T > AllocateIntrusiveShared(Alloc alloc, Args &&... args)
Constructs an object of type T and wraps it in an intrusive shared pointer using alloc as the storage...
Definition IntrusiveSharedPtr.hpp:209
SLEIPNIR_DLLEXPORT Variable sin(const Variable &x)
std::sin() for Variables.
Definition Variable.hpp:379
ExpressionType
Expression type.
Definition ExpressionType.hpp:14
@ kConstant
The expression is a constant.
@ kLinear
The expression is composed of linear and lower-order operators.
@ kNonlinear
The expression is composed of nonlinear and lower-order operators.
@ kQuadratic
The expression is composed of quadratic and lower-order operators.
PoolAllocator< T > GlobalPoolAllocator()
Returns an allocator for a global pool memory resource.
Definition Pool.hpp:158
IntrusiveSharedPtr< T > MakeIntrusiveShared(Args &&... args)
Constructs an object of type T and wraps it in an intrusive shared pointer using args as the paramete...
Definition IntrusiveSharedPtr.hpp:193
SLEIPNIR_DLLEXPORT Variable sqrt(const Variable &x)
std::sqrt() for Variables.
Definition Variable.hpp:397
An autodiff expression node.
Definition Expression.hpp:60
ExpressionType type
Expression argument type.
Definition Expression.hpp:97
constexpr Expression(double value, ExpressionType type=ExpressionType::kConstant)
Constructs a nullary expression (an operator with no arguments).
Definition Expression.hpp:142
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator-(const ExpressionPtr &lhs)
Unary minus operator.
Definition Expression.hpp:380
ExpressionPtr(*)(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &) TrinaryFuncExpr
Trinary function taking three expressions and returning an expression.
Definition Expression.hpp:74
double(*)(double, double, double) TrinaryFuncDouble
Trinary function taking three doubles and returning a double.
Definition Expression.hpp:69
constexpr Expression()=default
Constructs a constant expression with a value of zero.
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator/(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Expression-Expression division operator.
Definition Expression.hpp:260
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator-(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Expression-Expression subtraction operator.
Definition Expression.hpp:342
constexpr bool IsConstant(double constant) const
Returns true if the expression is the given constant.
Definition Expression.hpp:195
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator+(const ExpressionPtr &lhs)
Unary plus operator.
Definition Expression.hpp:407
ExpressionPtr adjointExpr
The adjoint of the expression node used during gradient expression tree generation.
Definition Expression.hpp:94
uint32_t refCount
Reference count for intrusive shared pointer.
Definition Expression.hpp:100
double(*)(double, double) BinaryFuncDouble
Binary function taking two doubles and returning a double.
Definition Expression.hpp:64
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator*(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Expression-Expression multiplication operator.
Definition Expression.hpp:205
double value
The value of the expression node.
Definition Expression.hpp:79
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator+(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Expression-Expression addition operator.
Definition Expression.hpp:308
constexpr Expression(ExpressionType type, BinaryFuncDouble valueFunc, TrinaryFuncDouble lhsGradientValueFunc, TrinaryFuncExpr lhsGradientFunc, ExpressionPtr lhs)
Constructs an unary expression (an operator with one argument).
Definition Expression.hpp:155
constexpr Expression(ExpressionType type, BinaryFuncDouble valueFunc, TrinaryFuncDouble lhsGradientValueFunc, TrinaryFuncDouble rhsGradientValueFunc, TrinaryFuncExpr lhsGradientFunc, TrinaryFuncExpr rhsGradientFunc, ExpressionPtr lhs, ExpressionPtr rhs)
Constructs a binary expression (an operator with two arguments).
Definition Expression.hpp:177
auto arg(const Char *name, const T &arg) -> detail::named_arg< Char, T >
Returns a named argument to be used in a formatting function.
Definition base.h:2775
sign
Definition base.h:685