20#include "kernels/Utils.h"
40 const int32_t *muldata = getTensorData<int32_t>(
multiples());
42 for (int32_t dim = 0; dim < num_dim; ++dim)
51 switch (
output()->element_type())
53 case DataType::FLOAT32:
57 throw std::runtime_error(
"luci-intp Tile Unsupported type.");
64template <
typename T,
typename M>
65void CopyMultipleTimes(
const T *in_data, int32_t in_size, M multiplier, T *out_data)
67 for (M i = 0; i < multiplier; ++i)
69 const T *in_end = in_data + in_size;
70 T *new_out_data = std::copy(in_data, in_end, out_data);
72 out_data = new_out_data;
76template <
typename T,
typename M>
77std::pair<int, int> TileOneDimension(
const tflite::RuntimeShape &in_dimensions,
const T *in_data,
78 const M *multiples, T *out_data,
int dimension)
80 if (in_dimensions.DimensionsCount() == 0)
84 return std::make_pair(0, 0);
87 const int dimension_size = in_dimensions.Dims(dimension);
88 if (dimension == in_dimensions.DimensionsCount() - 1)
91 return std::make_pair(dimension_size, dimension_size *
static_cast<int>(multiples[dimension]));
94 int total_stride_size = 0, total_tiled_stride_size = 0;
95 const T *copy_from_data = in_data;
96 T *copy_to_data = out_data;
97 for (
int i = 0; i < dimension_size; ++i)
99 int stride_size = 0, tiled_stride_size = 0;
100 std::tie(stride_size, tiled_stride_size) =
101 TileOneDimension(in_dimensions, copy_from_data, multiples, copy_to_data, dimension + 1);
102 copy_from_data += stride_size;
103 copy_to_data += tiled_stride_size;
104 total_stride_size += stride_size;
105 total_tiled_stride_size += tiled_stride_size;
108 out_data + total_tiled_stride_size);
109 return std::make_pair(total_stride_size,
110 static_cast<int>(total_tiled_stride_size * multiples[dimension]));
115void Tile::evalFloat()
const
118 getTensorData<int32_t>(
multiples()), getTensorData<float>(
output()), 0);
void resize(const Shape &new_shape)
const Shape & shape() const
const Tensor * multiples() const
const Tensor * input() const
Tile(const Tensor *input, const Tensor *multiplies, Tensor *output)
void execute() const override
void configure() override
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void CopyMultipleTimes(const T *in_data, int32_t in_size, M multiplier, T *out_data)
std::pair< int, int > TileOneDimension(const Shape &in_dimensions, const T *in_data, const M *multipliers, T *out_data, int dimension)