ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::Concatenation Class Reference

#include <Concatenation.h>

Collaboration diagram for luci_interpreter::kernels::Concatenation:

Public Member Functions

 Concatenation (std::vector< const Tensor * > inputs, Tensor *output, const ConcatenationParams &params)
 
const Tensor * input (int index) const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< ConcatenationParams >
const ConcatenationParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< ConcatenationParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const ConcatenationParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< ConcatenationParams >
const ConcatenationParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file Concatenation.h.

Constructor & Destructor Documentation

◆ Concatenation()

luci_interpreter::kernels::Concatenation::Concatenation ( std::vector< const Tensor * >  inputs,
Tensor * output,
const ConcatenationParams & params 
)

Definition at line 30 of file Concatenation.cpp.

32 : KernelWithParams<ConcatenationParams>(std::move(inputs), {output}, params)
33{
34}
const ConcatenationParams & params() const
Definition Kernel.h:67

References output().

Member Function Documentation

◆ configure()

void luci_interpreter::kernels::Concatenation::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 36 of file Concatenation.cpp.

37{
38 const int num_inputs = _inputs.size();
39 LUCI_INTERPRETER_CHECK(num_inputs > 0);
40 const Tensor *t0 = _inputs[0];
41
42 // TODO: Support concat with fused activation function
43 LUCI_INTERPRETER_CHECK(params().activation == luci::FusedActFunc::NONE);
44
45 int axis = _params.axis;
46 if (axis < 0)
47 axis += t0->shape().num_dims();
48 LUCI_INTERPRETER_CHECK(axis >= 0 && axis < t0->shape().num_dims());
49
50 int32_t sum_axis = t0->shape().dim(axis);
51 for (int i = 1; i < num_inputs; ++i)
52 {
53 const Tensor *tensor = _inputs[i];
54 LUCI_INTERPRETER_CHECK(tensor->element_type() == t0->element_type());
55 LUCI_INTERPRETER_CHECK(tensor->shape().num_dims() == t0->shape().num_dims());
56 for (int d = 0; d < t0->shape().num_dims(); ++d)
57 {
58 if (d == axis)
59 {
60 sum_axis += tensor->shape().dim(axis);
61 }
62 else
63 {
64 LUCI_INTERPRETER_CHECK(tensor->shape().dim(d) == t0->shape().dim(d));
65 }
66 }
67 }
68
69 Shape output_shape = t0->shape();
70 output_shape.dim(axis) = sum_axis;
71
72 // If input tensors are INT8 type then quantization parameters of all input tensors and the output
73 // should be the same
74 for (auto current_tensor : _inputs)
75 {
76 if (current_tensor->element_type() == DataType::S8)
77 {
78 LUCI_INTERPRETER_CHECK(current_tensor->quantized_dimension() ==
79 output()->quantized_dimension());
80
81 LUCI_INTERPRETER_CHECK(current_tensor->zero_points().size() ==
82 current_tensor->scales().size());
83 LUCI_INTERPRETER_CHECK(current_tensor->zero_points() == output()->zero_points());
84 LUCI_INTERPRETER_CHECK(current_tensor->scales() == output()->scales());
85 }
86 }
87 output()->resize(output_shape);
88}
int dim(int d) const
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
Definition Shape.h:28

References luci_interpreter::Kernel::_inputs, luci_interpreter::KernelWithParams< ConcatenationParams >::_params, luci_interpreter::ConcatenationParams::axis, luci_interpreter::Shape::dim(), luci_interpreter::Tensor::element_type(), LUCI_INTERPRETER_CHECK, luci::NONE, luci_interpreter::Shape::num_dims(), output(), output_shape, luci_interpreter::KernelWithParams< ConcatenationParams >::params(), luci_interpreter::Tensor::resize(), and luci_interpreter::Tensor::shape().

◆ execute()

void luci_interpreter::kernels::Concatenation::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 90 of file Concatenation.cpp.

91{
92 switch (_inputs[0]->element_type())
93 {
94 case DataType::FLOAT32:
95 evalGeneric<float>();
96 break;
97 case DataType::U8:
98 evalQuantized();
99 break;
100 case DataType::S8:
101 evalGeneric<int8_t>();
102 break;
103 case DataType::S32:
104 evalGeneric<int32_t>();
105 break;
106 case DataType::S64:
107 evalGeneric<int64_t>();
108 break;
109 default:
110 throw std::runtime_error("luci-intp Concatenation Unsupported type.");
111 }
112}

References luci_interpreter::Kernel::_inputs.

◆ input()

const Tensor * luci_interpreter::kernels::Concatenation::input ( int  index) const
inline

Definition at line 34 of file Concatenation.h.

34{ return _inputs[index]; }
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54

References luci_interpreter::Kernel::_inputs.

◆ output()

Tensor * luci_interpreter::kernels::Concatenation::output ( ) const
inline

Definition at line 35 of file Concatenation.h.

35{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by Concatenation(), and configure().


The documentation for this class was generated from the following files: