ONE - On-device Neural Engine
Loading...
Searching...
No Matches
mir_interpreter::GatherImpl< T, IndicesT > Struct Template Reference

Static Public Member Functions

static void run (const TensorVariant &datav, const TensorVariant &indicesv, const ops::GatherOp &op, mir::TensorVariant &res)
 

Detailed Description

template<typename T, typename IndicesT>
struct mir_interpreter::GatherImpl< T, IndicesT >

Definition at line 27 of file Gather.cpp.

Member Function Documentation

◆ run()

template<typename T , typename IndicesT >
void mir_interpreter::GatherImpl< T, IndicesT >::run ( const TensorVariant & datav,
const TensorVariant & indicesv,
const ops::GatherOp & op,
mir::TensorVariant & res 
)
static

Definition at line 34 of file Gather.cpp.

36{
37 const auto &data_shape = datav.getShape();
38 const auto &indices_shape = indicesv.getShape();
39 Tensor<T> data(datav);
40 Tensor<T> output(res);
41 Tensor<IndicesT> indices(indicesv);
42
43 int32_t axis = op.getAxis();
44 if (axis < 0)
45 axis += data_shape.rank();
46 assert(axis >= 0 && axis < data_shape.rank());
47 int32_t axis_size = data_shape.dim(axis);
48 int32_t num_indices = indices_shape.numElements();
49
50 int32_t outer_size = 1;
51 for (int32_t i = 0; i < axis; ++i)
52 outer_size *= data_shape.dim(i);
53
54 int32_t inner_size = 1;
55 for (int32_t i = axis + 1; i < data_shape.rank(); ++i)
56 inner_size *= data_shape.dim(i);
57
58 for (int32_t outer = 0; outer < outer_size; ++outer)
59 {
60 for (int32_t i = 0; i < num_indices; ++i)
61 {
62 auto index = indices.atOffset(i);
63 assert(index >= 0 && index < axis_size);
64 for (int32_t inner = 0; inner < inner_size; inner++)
65 {
66 output.atOffset((outer * num_indices + i) * inner_size + inner) =
67 data.atOffset((outer * axis_size + index) * inner_size + inner);
68 }
69 }
70 }
71}
int32_t getAxis() const
Definition GatherOp.h:46
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54

References mir::Tensor< T >::atOffset(), mir::ops::GatherOp::getAxis(), and mir::TensorVariant::getShape().

Referenced by package.infer.session::inference().


The documentation for this struct was generated from the following file: Gather.cpp