18#include "kernels/Utils.h"
// Helper declared `static`: takes the condition tensor and the list of input
// tensors and begins building a vector `result` whose first element is `cond`.
// NOTE(review): the function's braces and everything after `result{cond};`
// are absent from this excerpt — presumably `inputs` is appended after `cond`
// and `result` is returned; confirm against the full file.
27static std::vector<const Tensor *> joinInputs(
const Tensor *cond,
28 const std::vector<const Tensor *> &inputs)
30 std::vector<const Tensor *>
result{cond};
// Constructor of the If kernel.
// The Kernel base is initialised with joinInputs(cond, inputs) as its first
// argument and the moved `outputs` vector as its second; `then_graph` and
// `else_graph` are stored in the members _then_graph and _else_graph.
// NOTE(review): the tail of the parameter list (the declarations of
// then_graph / else_graph and the closing parenthesis) and the constructor
// body are missing from this excerpt.
35If::If(
const Tensor *cond,
const std::vector<const Tensor *> &inputs, std::vector<Tensor *> outputs,
37 :
Kernel(joinInputs(cond, inputs),
std::
move(outputs)), _then_graph(then_graph),
38 _else_graph(else_graph)
// Guard evaluated before the byte size is computed: a negative element count
// is rejected, and so is a count greater than SIZE_MAX / element_size, because
// for such a count num_elements * element_size would exceed SIZE_MAX.
// NOTE(review): element_size == 0 would make the division below divide by
// zero unless it is excluded earlier — the preceding code is not visible here.
90 if (num_elements < 0 ||
static_cast<uint64_t>(num_elements) >
SIZE_MAX / element_size)
// Both failure cases are reported with the same std::runtime_error message.
92 throw std::runtime_error(
"Integer overflow in size calculation");
// NOTE(review): the guard above bounds the product by SIZE_MAX, yet the result
// is stored in an int64_t. If element_size has a signed type, the product is
// computed in int64_t and may exceed INT64_MAX on targets where
// SIZE_MAX > INT64_MAX; the operand types are not visible in this excerpt —
// confirm, or compute the product directly in size_t.
95 const int64_t total_size = num_elements * element_size;
// Tail of a call whose beginning is missing from this excerpt; total_size is
// passed to it converted to size_t.
97 static_cast<size_t>(total_size));
const std::vector< Tensor * > & getOutputTensors() const
const std::vector< const Tensor * > & getInputTensors() const
const std::vector< Tensor * > & getInputTensors() const
int64_t large_num_elements() const
int32_t num_elements() const
void resize(const Shape &new_shape)
const Shape & shape() const
void configure() override
Tensor * output(int index) const
const Tensor * cond() const
void execute() const override
const Tensor * input(int index) const
If(const Tensor *cond, const std::vector< const Tensor * > &inputs, std::vector< Tensor * > outputs, RuntimeGraph *then_graph, RuntimeGraph *else_graph)
#define LUCI_INTERPRETER_CHECK(cond)
size_t getDataTypeSize(DataType data_type)
T must_cast(loco::Node *node)