20#include "kernels/Utils.h"
31static Shape extractShapeFromTensor(
const Tensor *tensor)
34 if (
tensor->element_type() == DataType::S32)
36 const auto *shape_data =
tensor->data<int32_t>();
37 for (
int i = 0; i <
tensor->shape().num_elements(); ++i)
39 shape.dim(i) = shape_data[i];
42 else if (
tensor->element_type() == DataType::S64)
44 const auto *shape_data =
tensor->data<int64_t>();
45 for (
int i = 0; i <
tensor->shape().num_elements(); ++i)
47 shape.dim(i) =
static_cast<int32_t
>(shape_data[i]);
59 const int32_t num_input_elements = input_shape.num_elements();
60 int32_t num_output_elements = 1;
61 int unknown_dim_index = -1;
67 assert(unknown_dim_index == -1);
68 unknown_dim_index = i;
72 num_output_elements *= value;
75 if (unknown_dim_index != -1)
77 output_shape->dim(unknown_dim_index) = num_input_elements / num_output_elements;
78 num_output_elements *=
output_shape->dim(unknown_dim_index);
98 const auto *input_data =
input()->
data<
void>();
103 std::memcpy(output_data, input_data, num_elements * element_size);
int32_t num_elements() const
void resize(const Shape &new_shape)
const Shape & shape() const
Reshape(const Tensor *input, const Tensor *shape, Tensor *output)
const Tensor * shape() const
void configure() override
void execute() const override
const Tensor * input() const
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
size_t getDataTypeSize(DataType data_type)