18#include "kernels/Utils.h"
27static std::vector<const Tensor *> joinInputs(
const Tensor *cond,
28 const std::vector<const Tensor *> &inputs)
30 std::vector<const Tensor *>
result{cond};
35If::If(
const Tensor *cond,
const std::vector<const Tensor *> &inputs, std::vector<Tensor *> outputs,
37 :
Kernel(joinInputs(cond, inputs),
std::
move(outputs)), _then_graph(then_graph),
38 _else_graph(else_graph)
89 num_elements * element_size);
const std::vector< Tensor * > & getOutputTensors() const
const std::vector< const Tensor * > & getInputTensors() const
const std::vector< Tensor * > & getInputTensors() const
int32_t num_elements() const
void resize(const Shape &new_shape)
const Shape & shape() const
void configure() override
Tensor * output(int index) const
const Tensor * cond() const
void execute() const override
const Tensor * input(int index) const
If(const Tensor *cond, const std::vector< const Tensor * > &inputs, std::vector< Tensor * > outputs, RuntimeGraph *then_graph, RuntimeGraph *else_graph)
#define LUCI_INTERPRETER_CHECK(cond)
size_t getDataTypeSize(DataType data_type)
T must_cast(loco::Node *node)