47template <
typename... Args>
86 uint32_t duplications = 0;
100 uint32_t refCount = 0;
115 std::array<TrinaryFuncDouble, 2> gradientValueFuncs{
nullptr,
nullptr};
125 std::array<TrinaryFuncExpr, 2> gradientFuncs{
nullptr,
nullptr};
128 std::array<ExpressionPtr, 2> args{
nullptr,
nullptr};
144 : value{value}, type{type} {}
158 : value{valueFunc(lhs->value, 0.0)},
160 valueFunc{valueFunc},
161 gradientValueFuncs{lhsGradientValueFunc, nullptr},
162 gradientFuncs{lhsGradientFunc, nullptr},
163 args{lhs, nullptr} {}
183 : value{valueFunc(lhs->value, rhs->value)},
185 valueFunc{valueFunc},
186 gradientValueFuncs{lhsGradientValueFunc, rhsGradientValueFunc},
187 gradientFuncs{lhsGradientFunc, rhsGradientFunc},
196 return type == ExpressionType::kConstant && value == constant;
210 if (lhs->IsConstant(0.0)) {
213 }
else if (rhs->IsConstant(0.0)) {
216 }
else if (lhs->IsConstant(1.0)) {
218 }
else if (rhs->IsConstant(1.0)) {
240 type, [](
double lhs,
double rhs) {
return lhs * rhs; },
241 [](double,
double rhs,
double parentAdjoint) {
242 return parentAdjoint * rhs;
244 [](
double lhs, double,
double parentAdjoint) {
245 return parentAdjoint * lhs;
248 const ExpressionPtr& parentAdjoint) {
return parentAdjoint * rhs; },
250 const ExpressionPtr& parentAdjoint) {
return parentAdjoint * lhs; },
265 if (lhs->IsConstant(0.0)) {
268 }
else if (rhs->IsConstant(1.0)) {
286 type, [](
double lhs,
double rhs) {
return lhs / rhs; },
287 [](double,
double rhs,
double parentAdjoint) {
288 return parentAdjoint / rhs;
290 [](
double lhs,
double rhs,
double parentAdjoint) {
291 return parentAdjoint * -lhs / (rhs * rhs);
294 const ExpressionPtr& parentAdjoint) {
return parentAdjoint / rhs; },
297 return parentAdjoint * -lhs / (rhs * rhs);
313 if (lhs ==
nullptr || lhs->IsConstant(0.0)) {
315 }
else if (rhs ==
nullptr || rhs->IsConstant(0.0)) {
325 std::max(lhs->type, rhs->type),
326 [](
double lhs,
double rhs) { return lhs + rhs; },
327 [](
double,
double,
double parentAdjoint) { return parentAdjoint; },
328 [](
double,
double,
double parentAdjoint) { return parentAdjoint; },
330 const ExpressionPtr& parentAdjoint) { return parentAdjoint; },
332 const ExpressionPtr& parentAdjoint) { return parentAdjoint; },
347 if (lhs->IsConstant(0.0)) {
348 if (rhs->IsConstant(0.0)) {
354 }
else if (rhs->IsConstant(0.0)) {
364 std::max(lhs->type, rhs->type),
365 [](
double lhs,
double rhs) { return lhs - rhs; },
366 [](
double,
double,
double parentAdjoint) { return parentAdjoint; },
367 [](
double,
double,
double parentAdjoint) { return -parentAdjoint; },
369 const ExpressionPtr& parentAdjoint) { return parentAdjoint; },
371 const ExpressionPtr& parentAdjoint) { return -parentAdjoint; },
384 if (lhs->IsConstant(0.0)) {
395 lhs->type, [](
double lhs,
double) { return -lhs; },
396 [](
double,
double,
double parentAdjoint) { return -parentAdjoint; },
398 const ExpressionPtr& parentAdjoint) { return -parentAdjoint; },
439 while (!stack.
empty()) {
440 auto elem = stack.
back();
445 if (--elem->refCount == 0) {
446 if (elem->adjointExpr !=
nullptr) {
449 for (
auto&&
arg : elem->args) {
450 if (
arg !=
nullptr) {
459 std::allocator_traits<
decltype(alloc)>::deallocate(alloc, elem,
487 kNonlinear, [](
double x,
double) {
return std::abs(x); },
488 [](
double x, double,
double parentAdjoint) {
490 return -parentAdjoint;
491 }
else if (x > 0.0) {
492 return parentAdjoint;
499 if (x->
value < 0.0) {
500 return -parentAdjoint;
501 }
else if (x->
value > 0.0) {
502 return parentAdjoint;
531 kNonlinear, [](
double x,
double) {
return std::acos(x); },
532 [](
double x, double,
double parentAdjoint) {
533 return -parentAdjoint / std::sqrt(1.0 - x * x);
537 return -parentAdjoint /
564 kNonlinear, [](
double x,
double) {
return std::asin(x); },
565 [](
double x, double,
double parentAdjoint) {
566 return parentAdjoint / std::sqrt(1.0 - x * x);
570 return parentAdjoint /
597 kNonlinear, [](
double x,
double) {
return std::atan(x); },
598 [](
double x, double,
double parentAdjoint) {
599 return parentAdjoint / (1.0 + x * x);
632 kNonlinear, [](
double y,
double x) {
return std::atan2(y, x); },
633 [](
double y,
double x,
double parentAdjoint) {
634 return parentAdjoint * x / (y * y + x * x);
636 [](
double y,
double x,
double parentAdjoint) {
637 return parentAdjoint * -y / (y * y + x * x);
641 return parentAdjoint * x / (y * y + x * x);
645 return parentAdjoint * -y / (y * y + x * x);
670 kNonlinear, [](
double x,
double) {
return std::cos(x); },
671 [](
double x, double,
double parentAdjoint) {
672 return -parentAdjoint * std::sin(x);
701 kNonlinear, [](
double x,
double) {
return std::cosh(x); },
702 [](
double x, double,
double parentAdjoint) {
703 return parentAdjoint * std::sinh(x);
733 kNonlinear, [](
double x,
double) {
return std::erf(x); },
734 [](
double x, double,
double parentAdjoint) {
735 return parentAdjoint * 2.0 * std::numbers::inv_sqrtpi *
740 return parentAdjoint *
767 kNonlinear, [](
double x,
double) {
return std::exp(x); },
768 [](
double x, double,
double parentAdjoint) {
769 return parentAdjoint * std::exp(x);
801 kNonlinear, [](
double x,
double y) {
return std::hypot(x, y); },
802 [](
double x,
double y,
double parentAdjoint) {
803 return parentAdjoint * x / std::hypot(x, y);
805 [](
double x,
double y,
double parentAdjoint) {
806 return parentAdjoint * y / std::hypot(x, y);
840 kNonlinear, [](
double x,
double) {
return std::log(x); },
841 [](
double x, double,
double parentAdjoint) {
return parentAdjoint / x; },
843 const ExpressionPtr& parentAdjoint) {
return parentAdjoint / x; },
868 kNonlinear, [](
double x,
double) {
return std::log10(x); },
869 [](
double x, double,
double parentAdjoint) {
870 return parentAdjoint / (std::numbers::ln10 * x);
909 [](
double base,
double power) { return std::pow(base, power); },
910 [](
double base,
double power,
double parentAdjoint) {
911 return parentAdjoint * std::pow(base, power - 1) * power;
913 [](
double base,
double power,
double parentAdjoint) {
918 return parentAdjoint * std::pow(base, power - 1) * base *
924 return parentAdjoint *
931 if (base->value == 0.0) {
935 return parentAdjoint *
953 if (x->value < 0.0) {
955 }
else if (x->value == 0.0) {
965 [](
double x,
double) {
968 }
else if (x == 0.0) {
974 [](double, double, double) {
return 0.0; },
992 if (x->IsConstant(0.0)) {
1003 kNonlinear, [](
double x,
double) {
return std::sin(x); },
1004 [](
double x, double,
double parentAdjoint) {
1005 return parentAdjoint * std::cos(x);
1023 if (x->IsConstant(0.0)) {
1034 kNonlinear, [](
double x,
double) {
return std::sinh(x); },
1035 [](
double x, double,
double parentAdjoint) {
1036 return parentAdjoint * std::cosh(x);
1056 if (x->value == 0.0) {
1059 }
else if (x->value == 1.0) {
1067 kNonlinear, [](
double x,
double) {
return std::sqrt(x); },
1068 [](
double x, double,
double parentAdjoint) {
1069 return parentAdjoint / (2.0 * std::sqrt(x));
1073 return parentAdjoint /
1089 if (x->IsConstant(0.0)) {
1100 kNonlinear, [](
double x,
double) {
return std::tan(x); },
1101 [](
double x, double,
double parentAdjoint) {
1102 return parentAdjoint / (std::cos(x) * std::cos(x));
1106 return parentAdjoint /
1121 if (x->IsConstant(0.0)) {
1132 kNonlinear, [](
double x,
double) {
return std::tanh(x); },
1133 [](
double x, double,
double parentAdjoint) {
1134 return parentAdjoint / (std::cosh(x) * std::cosh(x));
1138 return parentAdjoint /
This file defines the SmallVector class.
#define SLEIPNIR_DLLEXPORT
Definition SymbolExports.hpp:34
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition SmallVector.h:1212
reference emplace_back(ArgTypes &&... Args)
Definition SmallVector.h:953
void pop_back()
Definition SmallVector.h:441
bool empty() const
Definition SmallVector.h:102
reference back()
Definition SmallVector.h:324
Definition ExpressionGraph.hpp:13
SLEIPNIR_DLLEXPORT ExpressionPtr sin(const ExpressionPtr &x)
std::sin() for Expressions.
Definition Expression.hpp:987
SLEIPNIR_DLLEXPORT ExpressionPtr abs(const ExpressionPtr &x)
std::abs() for Expressions.
Definition Expression.hpp:471
SLEIPNIR_DLLEXPORT ExpressionPtr log10(const ExpressionPtr &x)
std::log10() for Expressions.
Definition Expression.hpp:852
SLEIPNIR_DLLEXPORT ExpressionPtr asin(const ExpressionPtr &x)
std::asin() for Expressions.
Definition Expression.hpp:548
SLEIPNIR_DLLEXPORT ExpressionPtr exp(const ExpressionPtr &x)
std::exp() for Expressions.
Definition Expression.hpp:752
static ExpressionPtr MakeExpressionPtr(Args &&... args)
Creates an intrusive shared pointer to an expression from the global pool allocator.
Definition Expression.hpp:48
SLEIPNIR_DLLEXPORT ExpressionPtr sinh(const ExpressionPtr &x)
std::sinh() for Expressions.
Definition Expression.hpp:1019
SLEIPNIR_DLLEXPORT ExpressionPtr log(const ExpressionPtr &x)
std::log() for Expressions.
Definition Expression.hpp:824
SLEIPNIR_DLLEXPORT ExpressionPtr hypot(const ExpressionPtr &x, const ExpressionPtr &y)
std::hypot() for Expressions.
Definition Expression.hpp:784
SLEIPNIR_DLLEXPORT ExpressionPtr cosh(const ExpressionPtr &x)
std::cosh() for Expressions.
Definition Expression.hpp:686
SLEIPNIR_DLLEXPORT ExpressionPtr acos(const ExpressionPtr &x)
std::acos() for Expressions.
Definition Expression.hpp:516
SLEIPNIR_DLLEXPORT ExpressionPtr pow(const ExpressionPtr &base, const ExpressionPtr &power)
std::pow() for Expressions.
Definition Expression.hpp:885
SLEIPNIR_DLLEXPORT ExpressionPtr atan2(const ExpressionPtr &y, const ExpressionPtr &x)
std::atan2() for Expressions.
Definition Expression.hpp:614
constexpr bool kUsePoolAllocator
Definition Expression.hpp:28
IntrusiveSharedPtr< Expression > ExpressionPtr
Typedef for intrusive shared pointer to Expression.
Definition Expression.hpp:39
SLEIPNIR_DLLEXPORT ExpressionPtr atan(const ExpressionPtr &x)
std::atan() for Expressions.
Definition Expression.hpp:581
SLEIPNIR_DLLEXPORT ExpressionPtr cos(const ExpressionPtr &x)
std::cos() for Expressions.
Definition Expression.hpp:655
void IntrusiveSharedPtrDecRefCount(Expression *expr)
Refcount decrement for intrusive shared pointer.
Definition Expression.hpp:431
void IntrusiveSharedPtrIncRefCount(Expression *expr)
Refcount increment for intrusive shared pointer.
Definition Expression.hpp:422
SLEIPNIR_DLLEXPORT ExpressionPtr sqrt(const ExpressionPtr &x)
std::sqrt() for Expressions.
Definition Expression.hpp:1050
SLEIPNIR_DLLEXPORT ExpressionPtr erf(const ExpressionPtr &x)
std::erf() for Expressions.
Definition Expression.hpp:717
IntrusiveSharedPtr< T > AllocateIntrusiveShared(Alloc alloc, Args &&... args)
Constructs an object of type T and wraps it in an intrusive shared pointer using alloc as the storage...
Definition IntrusiveSharedPtr.hpp:209
SLEIPNIR_DLLEXPORT Variable sin(const Variable &x)
std::sin() for Variables.
Definition Variable.hpp:379
ExpressionType
Expression type.
Definition ExpressionType.hpp:14
@ kConstant
The expression is a constant.
@ kLinear
The expression is composed of linear and lower-order operators.
@ kNonlinear
The expression is composed of nonlinear and lower-order operators.
@ kQuadratic
The expression is composed of quadratic and lower-order operators.
PoolAllocator< T > GlobalPoolAllocator()
Returns an allocator for a global pool memory resource.
Definition Pool.hpp:158
IntrusiveSharedPtr< T > MakeIntrusiveShared(Args &&... args)
Constructs an object of type T and wraps it in an intrusive shared pointer using args as the paramete...
Definition IntrusiveSharedPtr.hpp:193
SLEIPNIR_DLLEXPORT Variable sqrt(const Variable &x)
std::sqrt() for Variables.
Definition Variable.hpp:397
An autodiff expression node.
Definition Expression.hpp:60
ExpressionType type
Expression argument type.
Definition Expression.hpp:97
constexpr Expression(double value, ExpressionType type=ExpressionType::kConstant)
Constructs a nullary expression (an operator with no arguments).
Definition Expression.hpp:142
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator-(const ExpressionPtr &lhs)
Unary minus operator.
Definition Expression.hpp:380
ExpressionPtr(*)(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &) TrinaryFuncExpr
Trinary function taking three expressions and returning an expression.
Definition Expression.hpp:74
double(*)(double, double, double) TrinaryFuncDouble
Trinary function taking three doubles and returning a double.
Definition Expression.hpp:69
constexpr Expression()=default
Constructs a constant expression with a value of zero.
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator/(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Expression-Expression division operator.
Definition Expression.hpp:260
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator-(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Expression-Expression subtraction operator.
Definition Expression.hpp:342
constexpr bool IsConstant(double constant) const
Returns true if the expression is the given constant.
Definition Expression.hpp:195
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator+(const ExpressionPtr &lhs)
Unary plus operator.
Definition Expression.hpp:407
ExpressionPtr adjointExpr
The adjoint of the expression node used during gradient expression tree generation.
Definition Expression.hpp:94
uint32_t refCount
Reference count for intrusive shared pointer.
Definition Expression.hpp:100
double(*)(double, double) BinaryFuncDouble
Binary function taking two doubles and returning a double.
Definition Expression.hpp:64
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator*(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Expression-Expression multiplication operator.
Definition Expression.hpp:205
double value
The value of the expression node.
Definition Expression.hpp:79
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator+(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Expression-Expression addition operator.
Definition Expression.hpp:308
constexpr Expression(ExpressionType type, BinaryFuncDouble valueFunc, TrinaryFuncDouble lhsGradientValueFunc, TrinaryFuncExpr lhsGradientFunc, ExpressionPtr lhs)
Constructs an unary expression (an operator with one argument).
Definition Expression.hpp:155
constexpr Expression(ExpressionType type, BinaryFuncDouble valueFunc, TrinaryFuncDouble lhsGradientValueFunc, TrinaryFuncDouble rhsGradientValueFunc, TrinaryFuncExpr lhsGradientFunc, TrinaryFuncExpr rhsGradientFunc, ExpressionPtr lhs, ExpressionPtr rhs)
Constructs a binary expression (an operator with two arguments).
Definition Expression.hpp:177
auto arg(const Char *name, const T &arg) -> detail::named_arg< Char, T >
Returns a named argument to be used in a formatting function.
Definition base.h:2775
sign
Definition base.h:685