18#include "kernels/Utils.h"
20#include "PALBroadcastTo.h"
40 if (
tensor->element_type() == DataType::S32)
43 for (
int i = 0;
i <
tensor->shape().num_elements(); ++
i)
51 else if (
tensor->element_type() == DataType::S64)
54 for (
int i = 0;
i <
tensor->shape().num_elements(); ++
i)
105 switch (
input()->element_type())
107 case DataType::FLOAT32:
114 throw std::runtime_error(
"luci-intp BroadcastTo Unsupported type.");
118void BroadcastTo::evalFloat()
const
122 TfLiteType::kTfLiteFloat32);
125void BroadcastTo::evalBool()
const
129 TfLiteType::kTfLiteBool);
void resize(const Shape &new_shape)
const Shape & shape() const
void execute() const override
const Tensor * shape() const
BroadcastTo(const Tensor *input, const Tensor *shape, Tensor *output)
const Tensor * input() const
void configure() override
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
T must_cast(loco::Node *node)