18#include "kernels/Utils.h"
20#include "PALBroadcastTo.h"
40 if (
tensor->element_type() == DataType::S32)
42 const auto *shape_data =
tensor->data<int32_t>();
43 for (
int i = 0; i <
tensor->shape().num_elements(); ++i)
48 shape.dim(i) = shape_data[i];
51 else if (
tensor->element_type() == DataType::S64)
53 const auto *shape_data =
tensor->data<int64_t>();
54 for (
int i = 0; i <
tensor->shape().num_elements(); ++i)
59 shape.dim(i) =
static_cast<int32_t
>(shape_data[i]);
93 int extending_rank = output_rank - input_rank;
94 for (
int idx = 0; idx < input_rank; ++idx)
105 switch (
input()->element_type())
107 case DataType::FLOAT32:
114 throw std::runtime_error(
"luci-intp BroadcastTo Unsupported type.");
118void BroadcastTo::evalFloat()
const
122 TfLiteType::kTfLiteFloat32);
125void BroadcastTo::evalBool()
const
129 TfLiteType::kTfLiteBool);
void resize(const Shape &new_shape)
const Shape & shape() const
void execute() const override
const Tensor * shape() const
BroadcastTo(const Tensor *input, const Tensor *shape, Tensor *output)
const Tensor * input() const
void configure() override
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)