ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::Transpose Class Reference

#include <Transpose.h>

Collaboration diagram for luci_interpreter::kernels::Transpose:

Public Member Functions

 Transpose (const Tensor *input, const Tensor *perm, Tensor *output)
 
const Tensor * input () const
 
const Tensor * perm () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file Transpose.h.

Constructor & Destructor Documentation

◆ Transpose()

luci_interpreter::kernels::Transpose::Transpose ( const Tensor * input,
const Tensor * perm,
Tensor * output 
)

Definition at line 31 of file Transpose.cpp.

32 : Kernel({input, perm}, {output})
33{
34}
Kernel(std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
Definition Kernel.h:31
const Tensor * input() const
Definition Transpose.h:33
const Tensor * perm() const
Definition Transpose.h:34

References input(), and perm().

Member Function Documentation

◆ configure()

void luci_interpreter::kernels::Transpose::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 36 of file Transpose.cpp.

37{
38 // Transpose op only supports 1D-4D input arrays.
39 int dims = input()->shape().num_dims();
40 const int32_t *perm_data = getTensorData<int32_t>(perm());
41
42 assert(input()->shape().num_dims() <= 4);
43 assert(input()->element_type() == output()->element_type());
44
45 assert(perm()->shape().num_dims() == 1);
46 assert(perm()->shape().dim(0) == dims);
47
48 Shape output_shape(dims);
49 for (int i = 0; i < dims; i++)
50 {
51 assert(perm_data[i] < dims && perm_data[i] >= 0);
52 output_shape.dim(i) = input()->shape().dim(perm_data[i]);
53 }
54
55 output()->resize(output_shape);
56}
int32_t dim(int i) const
Definition Tensor.h:41
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
const luci_interpreter::RuntimeShape output_shape
Definition Shape.h:28

References luci_interpreter::Shape::dim(), input(), luci_interpreter::Shape::num_dims(), output(), output_shape, perm(), luci_interpreter::Tensor::resize(), and luci_interpreter::Tensor::shape().

◆ execute()

void luci_interpreter::kernels::Transpose::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 58 of file Transpose.cpp.

59{
60 tflite::TransposeParams params{};
61 const int32_t *perm_data = getTensorData<int32_t>(perm());
62 const int32_t size = perm()->shape().dim(0);
63 params.perm_count = size;
64 for (int i = 0; i < size; i++)
65 params.perm[i] = perm_data[i];
66 switch (input()->element_type())
67 {
68 case DataType::FLOAT32:
69 tflite::reference_ops::Transpose(params, getTensorShape(input()),
70 getTensorData<float>(input()), getTensorShape(output()),
71 getTensorData<float>(output()));
72 break;
73 case DataType::S64:
74 tflite::reference_ops::Transpose(params, getTensorShape(input()),
75 getTensorData<int64_t>(input()), getTensorShape(output()),
76 getTensorData<int64_t>(output()));
77 break;
78 case DataType::U8:
79 tflite::reference_ops::Transpose(params, getTensorShape(input()),
80 getTensorData<uint8_t>(input()), getTensorShape(output()),
81 getTensorData<uint8_t>(output()));
82 break;
83 default:
84 throw std::runtime_error("luci-intp Transpose Unsupported type.");
85 }
86}
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
int32_t size[5]
Definition Slice.cpp:35

References luci_interpreter::Shape::dim(), luci_interpreter::kernels::getTensorShape(), input(), output(), perm(), luci_interpreter::Tensor::shape(), and size.

◆ input()

const Tensor * luci_interpreter::kernels::Transpose::input ( ) const
inline

Definition at line 33 of file Transpose.h.

33{ return _inputs[0]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and Transpose().

◆ output()

Tensor * luci_interpreter::kernels::Transpose::output ( ) const
inline

Definition at line 35 of file Transpose.h.

35{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure(), and execute().

◆ perm()

const Tensor * luci_interpreter::kernels::Transpose::perm ( ) const
inline

Definition at line 34 of file Transpose.h.

34{ return _inputs[1]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and Transpose().


The documentation for this class was generated from the following files: