ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::Gather Class Reference

#include <Gather.h>

Collaboration diagram for luci_interpreter::kernels::Gather:

Public Member Functions

 Gather (const Tensor *params, const Tensor *indices, Tensor *output, const GatherParams &gparams)
 
const Tensor * params () const
 
const Tensor * indices () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< GatherParams >
const GatherParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< GatherParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const GatherParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< GatherParams >
const GatherParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file Gather.h.

Constructor & Destructor Documentation

◆ Gather()

luci_interpreter::kernels::Gather::Gather ( const Tensor * params,
const Tensor * indices,
Tensor * output,
const GatherParams & gparams 
)

Definition at line 31 of file Gather.cpp.

33 : KernelWithParams<GatherParams>({params, indices}, {output}, gparams)
34{
35}
const Tensor * indices() const
Definition Gather.h:34
const Tensor * params() const
Definition Gather.h:33

References indices(), and params().

Member Function Documentation

◆ configure()

void luci_interpreter::kernels::Gather::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 37 of file Gather.cpp.

38{
39 if (params()->element_type() == DataType::FLOAT32)
40 {
41 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
42 }
43 else
44 {
45 throw std::runtime_error("luci-intp Gather(1) Unsupported type.");
46 }
47
48 LUCI_INTERPRETER_CHECK(indices()->element_type() == DataType::S32 ||
49 indices()->element_type() == DataType::S64);
50
51 // refer tensorflow/lite/kernels/gather.cc
52
53 const Shape &params_shape = params()->shape();
54 const Shape &indices_shape = indices()->shape();
55
56 int axis = _params.axis;
57 if (axis < 0)
58 {
59 axis += params_shape.num_dims();
60 }
61 LUCI_INTERPRETER_CHECK(0 <= axis && axis < params_shape.num_dims());
62
63 int batch_dims = _params.batch_dims;
64 // batch_dims should be in range: [-rank(indices), rank(indices)].
65 // Negative batch_dims is added with rank of positions.
66 if (batch_dims < 0)
67 {
68 batch_dims += indices_shape.num_dims();
69 }
70 LUCI_INTERPRETER_CHECK(batch_dims <= axis);
71 LUCI_INTERPRETER_CHECK(0 <= batch_dims && batch_dims < params_shape.num_dims());
72 LUCI_INTERPRETER_CHECK(batch_dims <= indices_shape.num_dims());
73 for (int i = 0; i < batch_dims; ++i)
74 {
75 LUCI_INTERPRETER_CHECK(params_shape.dim(i) == indices_shape.dim(i));
76 }
77
78 const int num_dimensions = params_shape.num_dims() + indices_shape.num_dims() - 1 - batch_dims;
79
80 Shape output_shape(num_dimensions);
81 int output_index = 0;
82 for (int i = 0; i < axis; ++i)
83 {
84 output_shape.dim(output_index++) = params_shape.dim(i);
85 }
86 for (int i = batch_dims; i < indices_shape.num_dims(); ++i)
87 {
88 output_shape.dim(output_index++) = indices_shape.dim(i);
89 }
90 for (int i = axis + 1; i < params_shape.num_dims(); ++i)
91 {
92 output_shape.dim(output_index++) = params_shape.dim(i);
93 }
94 output()->resize(output_shape);
95}
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
Definition Shape.h:28

References luci_interpreter::KernelWithParams< GatherParams >::_params, luci_interpreter::GatherParams::axis, luci_interpreter::GatherParams::batch_dims, luci_interpreter::Shape::dim(), indices(), LUCI_INTERPRETER_CHECK, luci_interpreter::Shape::num_dims(), output(), output_shape, params(), luci_interpreter::Tensor::resize(), and luci_interpreter::Tensor::shape().

◆ execute()

void luci_interpreter::kernels::Gather::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 97 of file Gather.cpp.

98{
99 switch (params()->element_type())
100 {
101 case DataType::FLOAT32:
102 evalFloat();
103 break;
104 default:
105 throw std::runtime_error("luci-intp Gather(2) Unsupported type.");
106 }
107}

References params().

◆ indices()

const Tensor * luci_interpreter::kernels::Gather::indices ( ) const
inline

Definition at line 34 of file Gather.h.

34{ return _inputs[1]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and Gather().

◆ output()

Tensor * luci_interpreter::kernels::Gather::output ( ) const
inline

Definition at line 35 of file Gather.h.

35{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure().

◆ params()

const Tensor * luci_interpreter::kernels::Gather::params ( ) const
inline

Definition at line 33 of file Gather.h.

33{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and Gather().


The documentation for this class was generated from the following files:
Gather.h
Gather.cpp