40 const T *rhs_data,
const Shape &incoming_shape,
const T *incoming_data,
41 const Shape &lhs_grad_shape, T *lhs_grad_data,
42 const Shape &rhs_grad_shape, T *rhs_grad_data,
45 if (!(lhs_shape == rhs_shape && rhs_shape == incoming_shape && incoming_shape == lhs_grad_shape &&
46 lhs_grad_shape == rhs_grad_shape))
47 throw std::runtime_error{
"Shape of lhs, rhs, incoming, lhs_grad, and rhs_grad must match"};
49 switch (arithmetic_type)
53 BroadcastTo(incoming_shape,
const_cast<T *
>(incoming_data), lhs_grad_shape, lhs_grad_data);
54 BroadcastTo(incoming_shape,
const_cast<T *
>(incoming_data), rhs_grad_shape, rhs_grad_data);
60 BroadcastTo(incoming_shape,
const_cast<T *
>(incoming_data), lhs_grad_shape, lhs_grad_data);
62 auto const in_map =
MapAsVector(incoming_data, incoming_shape);
63 auto rhs_grad_map =
MapAsVector(rhs_grad_data, rhs_grad_shape);
64 rhs_grad_map = -in_map;
70 auto const in_map =
MapAsVector(incoming_data, incoming_shape);
71 auto const lhs_map =
MapAsVector(lhs_data, lhs_shape);
72 auto const rhs_map =
MapAsVector(rhs_data, rhs_shape);
73 auto lhs_grad_map =
MapAsVector(lhs_grad_data, lhs_grad_shape);
74 auto rhs_grad_map =
MapAsVector(rhs_grad_data, rhs_grad_shape);
76 lhs_grad_map = in_map.array() * rhs_map.array();
77 rhs_grad_map = in_map.array() * lhs_map.array();
83 throw std::runtime_error{
"Unsupported Binary Arithmetic Operation"};
void BinaryArithmeticGrad(const Shape &lhs_shape, const T *lhs_data, const Shape &rhs_shape, const T *rhs_data, const Shape &incoming_shape, const T *incoming_data, const Shape &lhs_grad_shape, T *lhs_grad_data, const Shape &rhs_grad_shape, T *rhs_grad_data, ArithmeticType arithmetic_type)