47template <
typename... Args>
 
   86  uint32_t duplications = 0;
 
  100  uint32_t refCount = 0;
 
  115  std::array<TrinaryFuncDouble, 2> gradientValueFuncs{
nullptr, 
nullptr};
 
  125  std::array<TrinaryFuncExpr, 2> gradientFuncs{
nullptr, 
nullptr};
 
  128  std::array<ExpressionPtr, 2> args{
nullptr, 
nullptr};
 
  144      : value{value}, type{type} {}
 
 
  158      : value{valueFunc(lhs->value, 0.0)},
 
  160        valueFunc{valueFunc},
 
  161        gradientValueFuncs{lhsGradientValueFunc, nullptr},
 
  162        gradientFuncs{lhsGradientFunc, nullptr},
 
  163        args{lhs, nullptr} {}
 
 
  183      : value{valueFunc(lhs->value, rhs->value)},
 
  185        valueFunc{valueFunc},
 
  186        gradientValueFuncs{lhsGradientValueFunc, rhsGradientValueFunc},
 
  187        gradientFuncs{lhsGradientFunc, rhsGradientFunc},
 
 
  196    return type == ExpressionType::kConstant && value == constant;
 
 
  210    if (lhs->IsConstant(0.0)) {
 
  213    } 
else if (rhs->IsConstant(0.0)) {
 
  216    } 
else if (lhs->IsConstant(1.0)) {
 
  218    } 
else if (rhs->IsConstant(1.0)) {
 
  240        type, [](
double lhs, 
double rhs) { 
return lhs * rhs; },
 
  241        [](double, 
double rhs, 
double parentAdjoint) {
 
  242          return parentAdjoint * rhs;
 
  244        [](
double lhs, double, 
double parentAdjoint) {
 
  245          return parentAdjoint * lhs;
 
  248           const ExpressionPtr& parentAdjoint) { 
return parentAdjoint * rhs; },
 
  250           const ExpressionPtr& parentAdjoint) { 
return parentAdjoint * lhs; },
 
 
  265    if (lhs->IsConstant(0.0)) {
 
  268    } 
else if (rhs->IsConstant(1.0)) {
 
  286        type, [](
double lhs, 
double rhs) { 
return lhs / rhs; },
 
  287        [](double, 
double rhs, 
double parentAdjoint) {
 
  288          return parentAdjoint / rhs;
 
  290        [](
double lhs, 
double rhs, 
double parentAdjoint) {
 
  291          return parentAdjoint * -lhs / (rhs * rhs);
 
  294           const ExpressionPtr& parentAdjoint) { 
return parentAdjoint / rhs; },
 
  297          return parentAdjoint * -lhs / (rhs * rhs);
 
 
  313    if (lhs == 
nullptr || lhs->IsConstant(0.0)) {
 
  315    } 
else if (rhs == 
nullptr || rhs->IsConstant(0.0)) {
 
  325        std::max(lhs->type, rhs->type),
 
  326        [](
double lhs, 
double rhs) { return lhs + rhs; },
 
  327        [](
double, 
double, 
double parentAdjoint) { return parentAdjoint; },
 
  328        [](
double, 
double, 
double parentAdjoint) { return parentAdjoint; },
 
  330           const ExpressionPtr& parentAdjoint) { return parentAdjoint; },
 
  332           const ExpressionPtr& parentAdjoint) { return parentAdjoint; },
 
 
  347    if (lhs->IsConstant(0.0)) {
 
  348      if (rhs->IsConstant(0.0)) {
 
  354    } 
else if (rhs->IsConstant(0.0)) {
 
  364        std::max(lhs->type, rhs->type),
 
  365        [](
double lhs, 
double rhs) { return lhs - rhs; },
 
  366        [](
double, 
double, 
double parentAdjoint) { return parentAdjoint; },
 
  367        [](
double, 
double, 
double parentAdjoint) { return -parentAdjoint; },
 
  369           const ExpressionPtr& parentAdjoint) { return parentAdjoint; },
 
  371           const ExpressionPtr& parentAdjoint) { return -parentAdjoint; },
 
 
  384    if (lhs->IsConstant(0.0)) {
 
  395        lhs->type, [](
double lhs, 
double) { return -lhs; },
 
  396        [](
double, 
double, 
double parentAdjoint) { return -parentAdjoint; },
 
  398           const ExpressionPtr& parentAdjoint) { return -parentAdjoint; },
 
 
 
  439  while (!stack.
empty()) {
 
  440    auto elem = stack.
back();
 
  445    if (--elem->refCount == 0) {
 
  446      if (elem->adjointExpr != 
nullptr) {
 
  449      for (
auto&& 
arg : elem->args) {
 
  450        if (
arg != 
nullptr) {
 
  459        std::allocator_traits<
decltype(alloc)>::deallocate(alloc, elem,
 
 
  487      kNonlinear, [](
double x, 
double) { 
return std::abs(x); },
 
  488      [](
double x, double, 
double parentAdjoint) {
 
  490          return -parentAdjoint;
 
  491        } 
else if (x > 0.0) {
 
  492          return parentAdjoint;
 
  499        if (x->
value < 0.0) {
 
  500          return -parentAdjoint;
 
  501        } 
else if (x->
value > 0.0) {
 
  502          return parentAdjoint;
 
 
  531      kNonlinear, [](
double x, 
double) { 
return std::acos(x); },
 
  532      [](
double x, double, 
double parentAdjoint) {
 
  533        return -parentAdjoint / std::sqrt(1.0 - x * x);
 
  537        return -parentAdjoint /
 
 
  564      kNonlinear, [](
double x, 
double) { 
return std::asin(x); },
 
  565      [](
double x, double, 
double parentAdjoint) {
 
  566        return parentAdjoint / std::sqrt(1.0 - x * x);
 
  570        return parentAdjoint /
 
 
  597      kNonlinear, [](
double x, 
double) { 
return std::atan(x); },
 
  598      [](
double x, double, 
double parentAdjoint) {
 
  599        return parentAdjoint / (1.0 + x * x);
 
 
  632      kNonlinear, [](
double y, 
double x) { 
return std::atan2(y, x); },
 
  633      [](
double y, 
double x, 
double parentAdjoint) {
 
  634        return parentAdjoint * x / (y * y + x * x);
 
  636      [](
double y, 
double x, 
double parentAdjoint) {
 
  637        return parentAdjoint * -y / (y * y + x * x);
 
  641        return parentAdjoint * x / (y * y + x * x);
 
  645        return parentAdjoint * -y / (y * y + x * x);
 
 
  670      kNonlinear, [](
double x, 
double) { 
return std::cos(x); },
 
  671      [](
double x, double, 
double parentAdjoint) {
 
  672        return -parentAdjoint * std::sin(x);
 
 
  701      kNonlinear, [](
double x, 
double) { 
return std::cosh(x); },
 
  702      [](
double x, double, 
double parentAdjoint) {
 
  703        return parentAdjoint * std::sinh(x);
 
 
  733      kNonlinear, [](
double x, 
double) { 
return std::erf(x); },
 
  734      [](
double x, double, 
double parentAdjoint) {
 
  735        return parentAdjoint * 2.0 * std::numbers::inv_sqrtpi *
 
  740        return parentAdjoint *
 
 
  767      kNonlinear, [](
double x, 
double) { 
return std::exp(x); },
 
  768      [](
double x, double, 
double parentAdjoint) {
 
  769        return parentAdjoint * std::exp(x);
 
 
  801      kNonlinear, [](
double x, 
double y) { 
return std::hypot(x, y); },
 
  802      [](
double x, 
double y, 
double parentAdjoint) {
 
  803        return parentAdjoint * x / std::hypot(x, y);
 
  805      [](
double x, 
double y, 
double parentAdjoint) {
 
  806        return parentAdjoint * y / std::hypot(x, y);
 
 
  840      kNonlinear, [](
double x, 
double) { 
return std::log(x); },
 
  841      [](
double x, double, 
double parentAdjoint) { 
return parentAdjoint / x; },
 
  843         const ExpressionPtr& parentAdjoint) { 
return parentAdjoint / x; },
 
 
  868      kNonlinear, [](
double x, 
double) { 
return std::log10(x); },
 
  869      [](
double x, double, 
double parentAdjoint) {
 
  870        return parentAdjoint / (std::numbers::ln10 * x);
 
 
  909      [](
double base, 
double power) { return std::pow(base, power); },
 
  910      [](
double base, 
double power, 
double parentAdjoint) {
 
  911        return parentAdjoint * std::pow(base, power - 1) * power;
 
  913      [](
double base, 
double power, 
double parentAdjoint) {
 
  918          return parentAdjoint * std::pow(base, power - 1) * base *
 
  924        return parentAdjoint *
 
  931        if (base->value == 0.0) {
 
  935          return parentAdjoint *
 
 
  953    if (x->value < 0.0) {
 
  955    } 
else if (x->value == 0.0) {
 
  965      [](
double x, 
double) {
 
  968        } 
else if (x == 0.0) {
 
  974      [](double, double, double) { 
return 0.0; },
 
 
  992  if (x->IsConstant(0.0)) {
 
 1003      kNonlinear, [](
double x, 
double) { 
return std::sin(x); },
 
 1004      [](
double x, double, 
double parentAdjoint) {
 
 1005        return parentAdjoint * std::cos(x);
 
 
 1023  if (x->IsConstant(0.0)) {
 
 1034      kNonlinear, [](
double x, 
double) { 
return std::sinh(x); },
 
 1035      [](
double x, double, 
double parentAdjoint) {
 
 1036        return parentAdjoint * std::cosh(x);
 
 
 1056    if (x->value == 0.0) {
 
 1059    } 
else if (x->value == 1.0) {
 
 1067      kNonlinear, [](
double x, 
double) { 
return std::sqrt(x); },
 
 1068      [](
double x, double, 
double parentAdjoint) {
 
 1069        return parentAdjoint / (2.0 * std::sqrt(x));
 
 1073        return parentAdjoint /
 
 
 1089  if (x->IsConstant(0.0)) {
 
 1100      kNonlinear, [](
double x, 
double) { 
return std::tan(x); },
 
 1101      [](
double x, double, 
double parentAdjoint) {
 
 1102        return parentAdjoint / (std::cos(x) * std::cos(x));
 
 1106        return parentAdjoint /
 
 
 1121  if (x->IsConstant(0.0)) {
 
 1132      kNonlinear, [](
double x, 
double) { 
return std::tanh(x); },
 
 1133      [](
double x, double, 
double parentAdjoint) {
 
 1134        return parentAdjoint / (std::cosh(x) * std::cosh(x));
 
 1138        return parentAdjoint /
 
 
This file defines the SmallVector class.
#define SLEIPNIR_DLLEXPORT
Definition SymbolExports.hpp:34
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition SmallVector.h:1212
reference emplace_back(ArgTypes &&... Args)
Definition SmallVector.h:953
void pop_back()
Definition SmallVector.h:441
bool empty() const
Definition SmallVector.h:102
reference back()
Definition SmallVector.h:324
Definition ExpressionGraph.hpp:13
SLEIPNIR_DLLEXPORT ExpressionPtr sin(const ExpressionPtr &x)
std::sin() for Expressions.
Definition Expression.hpp:987
SLEIPNIR_DLLEXPORT ExpressionPtr abs(const ExpressionPtr &x)
std::abs() for Expressions.
Definition Expression.hpp:471
SLEIPNIR_DLLEXPORT ExpressionPtr log10(const ExpressionPtr &x)
std::log10() for Expressions.
Definition Expression.hpp:852
SLEIPNIR_DLLEXPORT ExpressionPtr asin(const ExpressionPtr &x)
std::asin() for Expressions.
Definition Expression.hpp:548
SLEIPNIR_DLLEXPORT ExpressionPtr exp(const ExpressionPtr &x)
std::exp() for Expressions.
Definition Expression.hpp:752
static ExpressionPtr MakeExpressionPtr(Args &&... args)
Creates an intrusive shared pointer to an expression from the global pool allocator.
Definition Expression.hpp:48
SLEIPNIR_DLLEXPORT ExpressionPtr sinh(const ExpressionPtr &x)
std::sinh() for Expressions.
Definition Expression.hpp:1019
SLEIPNIR_DLLEXPORT ExpressionPtr log(const ExpressionPtr &x)
std::log() for Expressions.
Definition Expression.hpp:824
SLEIPNIR_DLLEXPORT ExpressionPtr hypot(const ExpressionPtr &x, const ExpressionPtr &y)
std::hypot() for Expressions.
Definition Expression.hpp:784
SLEIPNIR_DLLEXPORT ExpressionPtr cosh(const ExpressionPtr &x)
std::cosh() for Expressions.
Definition Expression.hpp:686
SLEIPNIR_DLLEXPORT ExpressionPtr acos(const ExpressionPtr &x)
std::acos() for Expressions.
Definition Expression.hpp:516
SLEIPNIR_DLLEXPORT ExpressionPtr pow(const ExpressionPtr &base, const ExpressionPtr &power)
std::pow() for Expressions.
Definition Expression.hpp:885
SLEIPNIR_DLLEXPORT ExpressionPtr atan2(const ExpressionPtr &y, const ExpressionPtr &x)
std::atan2() for Expressions.
Definition Expression.hpp:614
constexpr bool kUsePoolAllocator
Definition Expression.hpp:28
IntrusiveSharedPtr< Expression > ExpressionPtr
Typedef for intrusive shared pointer to Expression.
Definition Expression.hpp:39
SLEIPNIR_DLLEXPORT ExpressionPtr atan(const ExpressionPtr &x)
std::atan() for Expressions.
Definition Expression.hpp:581
SLEIPNIR_DLLEXPORT ExpressionPtr cos(const ExpressionPtr &x)
std::cos() for Expressions.
Definition Expression.hpp:655
void IntrusiveSharedPtrDecRefCount(Expression *expr)
Refcount decrement for intrusive shared pointer.
Definition Expression.hpp:431
void IntrusiveSharedPtrIncRefCount(Expression *expr)
Refcount increment for intrusive shared pointer.
Definition Expression.hpp:422
SLEIPNIR_DLLEXPORT ExpressionPtr sqrt(const ExpressionPtr &x)
std::sqrt() for Expressions.
Definition Expression.hpp:1050
SLEIPNIR_DLLEXPORT ExpressionPtr erf(const ExpressionPtr &x)
std::erf() for Expressions.
Definition Expression.hpp:717
IntrusiveSharedPtr< T > AllocateIntrusiveShared(Alloc alloc, Args &&... args)
Constructs an object of type T and wraps it in an intrusive shared pointer using alloc as the storage...
Definition IntrusiveSharedPtr.hpp:209
SLEIPNIR_DLLEXPORT Variable sin(const Variable &x)
std::sin() for Variables.
Definition Variable.hpp:379
ExpressionType
Expression type.
Definition ExpressionType.hpp:14
@ kConstant
The expression is a constant.
@ kLinear
The expression is composed of linear and lower-order operators.
@ kNonlinear
The expression is composed of nonlinear and lower-order operators.
@ kQuadratic
The expression is composed of quadratic and lower-order operators.
PoolAllocator< T > GlobalPoolAllocator()
Returns an allocator for a global pool memory resource.
Definition Pool.hpp:158
IntrusiveSharedPtr< T > MakeIntrusiveShared(Args &&... args)
Constructs an object of type T and wraps it in an intrusive shared pointer using args as the paramete...
Definition IntrusiveSharedPtr.hpp:193
SLEIPNIR_DLLEXPORT Variable sqrt(const Variable &x)
std::sqrt() for Variables.
Definition Variable.hpp:397
An autodiff expression node.
Definition Expression.hpp:60
ExpressionType type
Expression argument type.
Definition Expression.hpp:97
constexpr Expression(double value, ExpressionType type=ExpressionType::kConstant)
Constructs a nullary expression (an operator with no arguments).
Definition Expression.hpp:142
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator-(const ExpressionPtr &lhs)
Unary minus operator.
Definition Expression.hpp:380
ExpressionPtr(*)(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &) TrinaryFuncExpr
Trinary function taking three expressions and returning an expression.
Definition Expression.hpp:74
double(*)(double, double, double) TrinaryFuncDouble
Trinary function taking three doubles and returning a double.
Definition Expression.hpp:69
constexpr Expression()=default
Constructs a constant expression with a value of zero.
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator/(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Expression-Expression division operator.
Definition Expression.hpp:260
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator-(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Expression-Expression subtraction operator.
Definition Expression.hpp:342
constexpr bool IsConstant(double constant) const
Returns true if the expression is the given constant.
Definition Expression.hpp:195
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator+(const ExpressionPtr &lhs)
Unary plus operator.
Definition Expression.hpp:407
ExpressionPtr adjointExpr
The adjoint of the expression node used during gradient expression tree generation.
Definition Expression.hpp:94
uint32_t refCount
Reference count for intrusive shared pointer.
Definition Expression.hpp:100
double(*)(double, double) BinaryFuncDouble
Binary function taking two doubles and returning a double.
Definition Expression.hpp:64
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator*(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Expression-Expression multiplication operator.
Definition Expression.hpp:205
double value
The value of the expression node.
Definition Expression.hpp:79
friend SLEIPNIR_DLLEXPORT ExpressionPtr operator+(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Expression-Expression addition operator.
Definition Expression.hpp:308
constexpr Expression(ExpressionType type, BinaryFuncDouble valueFunc, TrinaryFuncDouble lhsGradientValueFunc, TrinaryFuncExpr lhsGradientFunc, ExpressionPtr lhs)
Constructs a unary expression (an operator with one argument).
Definition Expression.hpp:155
constexpr Expression(ExpressionType type, BinaryFuncDouble valueFunc, TrinaryFuncDouble lhsGradientValueFunc, TrinaryFuncDouble rhsGradientValueFunc, TrinaryFuncExpr lhsGradientFunc, TrinaryFuncExpr rhsGradientFunc, ExpressionPtr lhs, ExpressionPtr rhs)
Constructs a binary expression (an operator with two arguments).
Definition Expression.hpp:177
auto arg(const Char *name, const T &arg) -> detail::named_arg< Char, T >
Returns a named argument to be used in a formatting function.
Definition base.h:2775
sign
Definition base.h:685