18#include "kernels/Utils.h"
34 switch (
axis()->element_type())
36 case loco::DataType::S32:
37 axis_value = *getTensorData<int32_t>(
axis());
39 case loco::DataType::S64:
40 axis_value =
static_cast<int32_t
>(*getTensorData<int64_t>(
axis()));
43 throw std::runtime_error(
"luci-intp ExpandDims Unsupported type.");
50 axis_value += input_shape.
num_dims() + 1;
62 else if (i == axis_value)
79 const auto *input_data =
input()->
data<
void>();
84 std::memcpy(output_data, input_data, num_elements * element_size);
int32_t num_elements() const
void resize(const Shape &new_shape)
const Shape & shape() const
const Tensor * axis() const
void execute() const override
void configure() override
const Tensor * input() const
ExpandDims(const Tensor *input, const Tensor *axis, Tensor *output)
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
size_t getDataTypeSize(DataType data_type)