ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::Unpack Class Reference

#include <Unpack.h>

Collaboration diagram for luci_interpreter::kernels::Unpack:

Public Member Functions

 Unpack (const Tensor *input, std::vector< Tensor * > outputs, const UnpackParams &params)
 
const Tensor * input () const
 
Tensor * output (int index) const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< UnpackParams >
const UnpackParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< UnpackParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const UnpackParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< UnpackParams >
const UnpackParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file Unpack.h.

Constructor & Destructor Documentation

◆ Unpack()

luci_interpreter::kernels::Unpack::Unpack ( const Tensor * input,
std::vector< Tensor * >  outputs,
const UnpackParams & params 
)

Definition at line 31 of file Unpack.cpp.

32 : KernelWithParams<UnpackParams>({input}, std::move(outputs), params)
33{
34}
const UnpackParams & params() const
Definition Kernel.h:67
const Tensor * input() const
Definition Unpack.h:33

References input().

Member Function Documentation

◆ configure()

void luci_interpreter::kernels::Unpack::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 36 of file Unpack.cpp.

37{
38 const Shape &input_shape = input()->shape();
39
40 int axis = _params.axis;
41 if (axis < 0)
42 axis += input()->shape().num_dims();
43 assert(axis >= 0 && axis < input_shape.num_dims());
44
45 Shape output_shape(input_shape.num_dims() - 1);
46 int out_index = 0;
47 for (int in_index = 0; in_index < input_shape.num_dims(); ++in_index)
48 {
49 if (in_index != axis)
50 output_shape.dim(out_index++) = input_shape.dim(in_index);
51 }
52
53 for (Tensor *output : _outputs)
54 {
55 assert(output->element_type() == input()->element_type());
56 output->resize(output_shape);
57 }
58}
const std::vector< Tensor * > _outputs
Definition Kernel.h:53
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
DataType element_type() const
Definition Tensor.h:105
Tensor * output(int index) const
Definition Unpack.h:34
const luci_interpreter::RuntimeShape output_shape
Definition Shape.h:28

References luci_interpreter::Kernel::_outputs, luci_interpreter::KernelWithParams< UnpackParams >::_params, luci_interpreter::UnpackParams::axis, luci_interpreter::Shape::dim(), luci_interpreter::Tensor::element_type(), input(), luci_interpreter::Shape::num_dims(), output(), output_shape, luci_interpreter::Tensor::resize(), and luci_interpreter::Tensor::shape().

◆ execute()

void luci_interpreter::kernels::Unpack::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 70 of file Unpack.cpp.

71{
72 switch (input()->element_type())
73 {
74 case DataType::FLOAT32:
75 return executeImpl<float>();
76 case DataType::U8:
77 return executeImpl<uint8_t>();
78 default:
79 throw std::runtime_error("luci-intp Unpack Unsupported type.");
80 }
81}

References input().

◆ input()

const Tensor * luci_interpreter::kernels::Unpack::input ( ) const
inline

Definition at line 33 of file Unpack.h.

33{ return _inputs[0]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and Unpack().

◆ output()

Tensor * luci_interpreter::kernels::Unpack::output ( int  index) const
inline

Definition at line 34 of file Unpack.h.

34{ return _outputs[index]; }
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54

References luci_interpreter::Kernel::_outputs.

Referenced by configure().


The documentation for this class was generated from the following files: