ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::ReverseV2 Class Reference

#include <ReverseV2.h>

Collaboration diagram for luci_interpreter::kernels::ReverseV2:

Public Member Functions

 ReverseV2 (const Tensor *input, const Tensor *axes, Tensor *output)
 
const Tensor * input () const
 
const Tensor * axes () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
 ReverseV2 (const Tensor *input, const Tensor *axes, Tensor *output)
 
const Tensor * input () const
 
const Tensor * axes () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 27 of file ReverseV2.h.

Constructor & Destructor Documentation

◆ ReverseV2() [1/2]

luci_interpreter::kernels::ReverseV2::ReverseV2 ( const Tensor * input,
const Tensor * axes,
Tensor * output 
)

Definition at line 27 of file ReverseV2.cpp.

28 : Kernel({input, axes}, {output})
29{
30}
Kernel(std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
Definition Kernel.h:31
const Tensor * input() const
Definition ReverseV2.h:32
const Tensor * axes() const
Definition ReverseV2.h:33

References axes(), and input().

◆ ReverseV2() [2/2]

luci_interpreter::kernels::ReverseV2::ReverseV2 ( const Tensor * input,
const Tensor * axes,
Tensor * output 
)

Member Function Documentation

◆ axes() [1/2]

const Tensor * luci_interpreter::kernels::ReverseV2::axes ( ) const
inline

Definition at line 33 of file ReverseV2.h.

33{ return _inputs[1]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and ReverseV2().

◆ axes() [2/2]

const Tensor * luci_interpreter::kernels::ReverseV2::axes ( ) const
inline

Definition at line 33 of file ReverseV2.h.

33{ return _inputs[1]; }

References luci_interpreter::Kernel::_inputs.

◆ configure() [1/2]

void luci_interpreter::kernels::ReverseV2::configure ( )
overridevirtual

Implements luci_interpreter::Kernel.

Definition at line 32 of file ReverseV2.cpp.

33{
34 assert(axes()->shape().num_dims() == 1);
35 assert(input()->shape().num_dims() >= axes()->shape().num_elements());
36 if (input()->element_type() != DataType::S32 && input()->element_type() != DataType::FLOAT32 &&
37 input()->element_type() != DataType::U8 && input()->element_type() != DataType::S16 &&
38 input()->element_type() != DataType::S64)
39 {
40 throw std::runtime_error("Unsupported input type.");
41 }
42 if (axes()->element_type() != DataType::S32)
43 {
44 throw std::runtime_error("Unsupported axes type.");
45 }
46 if (axes()->shape().num_elements() > 1)
47 {
48 throw std::runtime_error("Current implementation does not support more than 1 axis.");
49 }
50 int axis_value = getTensorData<int32_t>(axes())[0];
51 if (axis_value < 0 || axis_value >= input()->shape().num_dims())
52 {
53 throw std::runtime_error("Invalid axes value");
54 }
55 assert(input()->element_type() == output()->element_type());
56
57 output()->resize(input()->shape());
58}
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
uint32_t num_elements(const Shape &shape)
The number of elements of a feature map of a given shape.
Definition Shape.h:59

References axes(), input(), output(), and luci_interpreter::Tensor::resize().

◆ configure() [2/2]

void luci_interpreter::kernels::ReverseV2::configure ( )
overridevirtual

◆ execute() [1/2]

void luci_interpreter::kernels::ReverseV2::execute ( ) const
overridevirtual

Implements luci_interpreter::Kernel.

Definition at line 60 of file ReverseV2.cpp.

61{
62 int axis_value = getTensorData<int32_t>(axes())[0];
63 switch (output()->element_type())
64 {
65 case DataType::FLOAT32:
66 tflite::reference_ops::Reverse<float>(axis_value, getTensorShape(input()),
67 getTensorData<float>(input()), getTensorShape(output()),
68 getTensorData<float>(output()));
69 break;
70 case DataType::U8:
71 tflite::reference_ops::Reverse<uint8_t>(
72 axis_value, getTensorShape(input()), getTensorData<uint8_t>(input()),
73 getTensorShape(output()), getTensorData<uint8_t>(output()));
74 break;
75 default:
76 throw std::runtime_error("Unsupported output type");
77 }
78}
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194

References axes(), luci_interpreter::kernels::getTensorShape(), input(), and output().

◆ execute() [2/2]

void luci_interpreter::kernels::ReverseV2::execute ( ) const
overridevirtual

◆ input() [1/2]

const Tensor * luci_interpreter::kernels::ReverseV2::input ( ) const
inline

Definition at line 32 of file ReverseV2.h.

32{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and ReverseV2().

◆ input() [2/2]

const Tensor * luci_interpreter::kernels::ReverseV2::input ( ) const
inline

Definition at line 32 of file ReverseV2.h.

32{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

◆ output() [1/2]

Tensor * luci_interpreter::kernels::ReverseV2::output ( ) const
inline

Definition at line 34 of file ReverseV2.h.

34{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure(), and execute().

◆ output() [2/2]

Tensor * luci_interpreter::kernels::ReverseV2::output ( ) const
inline

Definition at line 34 of file ReverseV2.h.

34{ return _outputs[0]; }

References luci_interpreter::Kernel::_outputs.


The documentation for this class was generated from the following files: