ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
luci_interpreter::kernels::Gather Class Reference

#include <Gather.h>

Collaboration diagram for luci_interpreter::kernels::Gather:

Public Member Functions

 Gather (const Tensor *params, const Tensor *indices, Tensor *output, const GatherParams &gparams)
 
const Tensor * params () const
 
const Tensor * indices () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< GatherParams >
const GatherParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< GatherParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const GatherParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< GatherParams >
const GatherParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file Gather.h.

Constructor & Destructor Documentation

◆ Gather()

luci_interpreter::kernels::Gather::Gather ( const Tensor * params,
const Tensor * indices,
Tensor * output,
const GatherParams & gparams 
)

Definition at line 31 of file Gather.cpp.

33 : KernelWithParams<GatherParams>({params, indices}, {output}, gparams)
34{
35}
const Tensor * indices() const
Definition Gather.h:34
const Tensor * params() const
Definition Gather.h:33

References indices(), and params().

Member Function Documentation

◆ configure()

void luci_interpreter::kernels::Gather::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 37 of file Gather.cpp.

38{
39 if (params()->element_type() == DataType::FLOAT32)
40 {
41 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
42 }
43 else if (params()->element_type() == DataType::S32)
44 {
45 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::S32);
46 }
47 else
48 {
49 throw std::runtime_error("luci-intp Gather(1) Unsupported type.");
50 }
51
52 LUCI_INTERPRETER_CHECK(indices()->element_type() == DataType::S32 ||
53 indices()->element_type() == DataType::S64);
54
55 // refer tensorflow/lite/kernels/gather.cc
56
57 const Shape &params_shape = params()->shape();
58 Shape indices_shape = indices()->shape();
59 {
60 // scalar index is treated as a tensor with the shape of [1]
61 if (indices_shape.num_dims() == 0)
62 {
63 indices_shape = Shape({1});
64 }
65 }
66
67 int axis = _params.axis;
68 if (axis < 0)
69 {
70 axis += params_shape.num_dims();
71 }
72 LUCI_INTERPRETER_CHECK(0 <= axis && axis < params_shape.num_dims());
73
74 int batch_dims = _params.batch_dims;
75 // batch_dims should be in range: [-rank(indices), rank(indices)].
76 // Negative batch_dims is added with rank of positions.
77 if (batch_dims < 0)
78 {
79 batch_dims += indices_shape.num_dims();
80 }
81 LUCI_INTERPRETER_CHECK(batch_dims <= axis);
82 LUCI_INTERPRETER_CHECK(0 <= batch_dims && batch_dims < params_shape.num_dims());
83 LUCI_INTERPRETER_CHECK(batch_dims <= indices_shape.num_dims());
84 for (int i = 0; i < batch_dims; ++i)
85 {
86 LUCI_INTERPRETER_CHECK(params_shape.dim(i) == indices_shape.dim(i));
87 }
88
89 const int num_dimensions = params_shape.num_dims() + indices_shape.num_dims() - 1 - batch_dims;
90
91 Shape output_shape(num_dimensions);
92 int output_index = 0;
93 for (int i = 0; i < axis; ++i)
94 {
95 output_shape.dim(output_index++) = params_shape.dim(i);
96 }
97 for (int i = batch_dims; i < indices_shape.num_dims(); ++i)
98 {
99 output_shape.dim(output_index++) = indices_shape.dim(i);
100 }
101 for (int i = axis + 1; i < params_shape.num_dims(); ++i)
102 {
103 output_shape.dim(output_index++) = params_shape.dim(i);
104 }
105 output()->resize(output_shape);
106}
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
Definition Shape.h:28

References luci_interpreter::KernelWithParams< GatherParams >::_params, luci_interpreter::GatherParams::axis, luci_interpreter::GatherParams::batch_dims, luci_interpreter::Shape::dim(), indices(), LUCI_INTERPRETER_CHECK, luci_interpreter::Shape::num_dims(), output(), output_shape, params(), luci_interpreter::Tensor::resize(), and luci_interpreter::Tensor::shape().

◆ execute()

void luci_interpreter::kernels::Gather::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 108 of file Gather.cpp.

109{
110 switch (params()->element_type())
111 {
112 case DataType::FLOAT32:
113 eval<float>();
114 break;
115 case DataType::S32:
116 eval<int32_t>();
117 break;
118 default:
119 throw std::runtime_error("luci-intp Gather(2) Unsupported type.");
120 }
121}

References params().

◆ indices()

const Tensor * luci_interpreter::kernels::Gather::indices ( ) const
inline

Definition at line 34 of file Gather.h.

34{ return _inputs[1]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and Gather().

◆ output()

Tensor * luci_interpreter::kernels::Gather::output ( ) const
inline

Definition at line 35 of file Gather.h.

35{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure().

◆ params()

const Tensor * luci_interpreter::kernels::Gather::params ( ) const
inline

Definition at line 33 of file Gather.h.

33{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and Gather().


The documentation for this class was generated from the following files: