19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
44 auto t_shape =
t()->
shape();
45 auto e_shape =
e()->
shape();
56 case DataType::FLOAT32:
66 throw std::runtime_error(
"luci-intp SelectV2 unsupported type.");
70template <
typename T>
void SelectV2::evaluate()
const
73 const auto condition_data = getTensorData<bool>(
condition());
75 const auto t_data = getTensorData<T>(
t());
77 const auto e_data = getTensorData<T>(
e());
79 auto output_data = getTensorData<T>(
output());
81 tflite::reference_ops::BroadcastSelect5DSlow<bool, T>(
82 condition_shape, condition_data, t_shape, t_data, e_shape, e_data,
output_shape, output_data);
void resize(const Shape &new_shape)
const Shape & shape() const
DataType element_type() const
const Tensor * condition() const
SelectV2(const Tensor *cond, const Tensor *t, const Tensor *e, Tensor *output)
void execute() const override
void configure() override
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
Shape calculateShapeForBroadcast(const Shape &input1_shape, const Shape &input2_shape)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)