19#include "kernels/Utils.h"
21#include "PALSelectV2.h"
37 RuntimeGraph *runtime_graph)
39 using Func =
decltype(luci_interpreter_pal::Select<bool, T>) *;
43 assert(
false &&
"Broadcast not supported now");
47 select_func = luci_interpreter_pal::Select<bool, T>;
51 kernels::getTensorData<bool>(runtime_graph->getDataByTensor(
input_condition)),
53 kernels::getTensorData<T>(runtime_graph->getDataByTensor(
input_x)),
55 kernels::getTensorData<T>(runtime_graph->getDataByTensor(
input_y)),
57 kernels::getTensorData<T>(runtime_graph->getDataByTensor(output)));
93 Tensor::num_elements(
input_y) == 1 && Tensor::num_elements(output) == 1;
95 bool same_shape = Tensor::num_elements(
input_cond) == Tensor::num_elements(
input_x) &&
128 Tensor::num_elements(
input_y) == 1 && Tensor::num_elements(output) == 1;
130 bool same_shape = Tensor::num_elements(
input_cond) == Tensor::num_elements(
input_x) &&
132 bool is_broadcast =
false;
140 case DataType::FLOAT32:
145 assert(
false &&
"Unsupported type.");
const circle::Tensor * getCircleTensorByIndex(int32_t index)
#define LUCI_INTERPRETER_CHECK(cond)
luci_interpreter::RuntimeShape getTensorRuntimeShape(const circle::Tensor *circle_tensor, BaseRuntimeGraph *runtime_graph)
void configure_kernel_CircleSelectV2(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
void execute_kernel_CircleSelectV2(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
T must_cast(loco::Node *node)