ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::Pack Class Reference

#include <Pack.h>

Collaboration diagram for luci_interpreter::kernels::Pack:

Public Member Functions

 Pack (std::vector< const Tensor * > inputs, Tensor *output, const PackParams &params)
 
const Tensor * input (int index) const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< PackParams >
const PackParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< PackParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const PackParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< PackParams >
const PackParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file Pack.h.

Constructor & Destructor Documentation

◆ Pack()

luci_interpreter::kernels::Pack::Pack ( std::vector< const Tensor * >  inputs,
Tensor * output,
const PackParams & params 
)

Definition at line 30 of file Pack.cpp.

31 : KernelWithParams<PackParams>(std::move(inputs), {output}, params)
32{
33}
Tensor * output() const
Definition Pack.h:34

References output().

Member Function Documentation

◆ configure()

void luci_interpreter::kernels::Pack::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 35 of file Pack.cpp.

36{
37 LUCI_INTERPRETER_CHECK(_inputs.size() == static_cast<uint32_t>(params().values_count));
38 const Tensor *t0 = _inputs[0];
39 const int dimension_size = t0->shape().num_dims() + 1;
40 int axis = params().axis;
41 if (axis < 0)
42 {
43 axis += dimension_size;
44 }
45 LUCI_INTERPRETER_CHECK(axis >= 0 && axis <= t0->shape().num_dims());
46
47 if (t0->element_type() != DataType::S32 && t0->element_type() != DataType::FLOAT32 &&
48 t0->element_type() != DataType::U8 && t0->element_type() != DataType::S8 &&
49 t0->element_type() != DataType::S16 && t0->element_type() != DataType::S64)
50 {
51 throw std::runtime_error("luci-intp Pack(1) Unsupported type.");
52 }
53
54 for (uint32_t i = 1; i < _inputs.size(); ++i)
55 {
56 const Tensor *tensor = _inputs[i];
57 LUCI_INTERPRETER_CHECK(tensor->element_type() == t0->element_type());
58 LUCI_INTERPRETER_CHECK(tensor->shape().num_dims() == t0->shape().num_dims());
59 for (int d = 0; d < t0->shape().num_dims(); ++d)
60 {
61 LUCI_INTERPRETER_CHECK(tensor->shape().dim(d) == t0->shape().dim(d));
62 }
63 }
64
65 Shape output_shape(dimension_size);
66 int i = 0;
67 for (int index = 0; index < dimension_size; ++index)
68 {
69 if (index == axis)
70 {
71 output_shape.dim(index) = params().values_count;
72 }
73 else
74 {
75 output_shape.dim(index) = t0->shape().dim(i++);
76 }
77 }
78
79 if (t0->element_type() == DataType::U8 || t0->element_type() == DataType::S8 ||
80 t0->element_type() == DataType::S16)
81 {
82 LUCI_INTERPRETER_CHECK(output()->zero_point() == t0->zero_point());
83 LUCI_INTERPRETER_CHECK(output()->scale() == t0->scale());
84 // Guarantee input/output quantization params match as we do not support
85 // packing quantized tensors.
86 for (int i = 0; i < params().values_count; i++)
87 {
88 LUCI_INTERPRETER_CHECK(_inputs[i]->zero_point() == t0->zero_point());
89 LUCI_INTERPRETER_CHECK(_inputs[i]->scale() == t0->scale());
90 }
91 }
92
93 output()->resize(output_shape);
94}
int dim(int d) const
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54
Definition Shape.h:28

References luci_interpreter::Kernel::_inputs, luci_interpreter::PackParams::axis, luci_interpreter::Shape::dim(), luci_interpreter::Tensor::element_type(), LUCI_INTERPRETER_CHECK, luci_interpreter::Shape::num_dims(), output(), output_shape, luci_interpreter::KernelWithParams< PackParams >::params(), luci_interpreter::Tensor::resize(), luci_interpreter::Tensor::scale(), luci_interpreter::Tensor::shape(), luci_interpreter::PackParams::values_count, and luci_interpreter::Tensor::zero_point().

◆ execute()

void luci_interpreter::kernels::Pack::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 96 of file Pack.cpp.

97{
98 switch (_inputs[0]->element_type())
99 {
100 case DataType::FLOAT32:
101 evalGeneric<float>();
102 break;
103 case DataType::U8:
104 evalGeneric<uint8_t>();
105 break;
106 case DataType::S8:
107 evalGeneric<int8_t>();
108 break;
109 case DataType::S16:
110 evalGeneric<int16_t>();
111 break;
112 case DataType::S32:
113 evalGeneric<int32_t>();
114 break;
115 case DataType::S64:
116 evalGeneric<int64_t>();
117 break;
118 default:
119 throw std::runtime_error("luci-intp Pack(2) Unsupported type.");
120 }
121}

References luci_interpreter::Kernel::_inputs.

◆ input()

const Tensor * luci_interpreter::kernels::Pack::input ( int  index) const
inline

Definition at line 33 of file Pack.h.

33{ return _inputs[index]; }

References luci_interpreter::Kernel::_inputs.

◆ output()

Tensor * luci_interpreter::kernels::Pack::output ( ) const
inline

Definition at line 34 of file Pack.h.

34{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure(), and Pack().


The documentation for this class was generated from the following files: