19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
36 _has_low_rank_input_condition =
false;
46 auto cond_num_dims = cond_shape.
num_dims();
47 auto t_shape =
t()->
shape();
49 bool is_input_condition_scalar = cond_num_dims == 0;
50 bool has_rank_one_input_condition = cond_num_dims == 1 && cond_shape.
dim(0) == t_shape.dim(0);
52 _has_low_rank_input_condition = is_input_condition_scalar || has_rank_one_input_condition;
59 switch (
t()->element_type())
61 case DataType::FLOAT32:
65 throw std::runtime_error(
"luci-intp Select unsupported type.");
69void Select::evalFloat()
const
72 const auto condition_data = getTensorData<bool>(
condition());
74 const auto t_data = getTensorData<float>(
t());
76 const auto e_data = getTensorData<float>(
e());
78 auto output_data = getTensorData<float>(
output());
80 if (_has_low_rank_input_condition)
void resize(const Shape &new_shape)
const Shape & shape() const
void configure() override
Select(const Tensor *cond, const Tensor *t, const Tensor *e, Tensor *output)
void execute() const override
const Tensor * condition() const
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
Shape calculateShapeForBroadcast(const Shape &input1_shape, const Shape &input2_shape)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void Select(const RuntimeShape &input_condition_shape, const D *input_condition_data, const RuntimeShape &input_x_shape, const T *input_x_data, const RuntimeShape &input_y_shape, const T *input_y_data, const RuntimeShape &output_shape, T *output_data)
void RankOneSelect(const RuntimeShape &input_condition_shape, const D *input_condition_data, const RuntimeShape &input_x_shape, const T *input_x_data, const RuntimeShape &input_y_shape, const T *input_y_data, const RuntimeShape &output_shape, T *output_data)