ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::OneHot Class Reference

#include <OneHot.h>

Collaboration diagram for luci_interpreter::kernels::OneHot:

Public Member Functions

 OneHot (const Tensor *indices, const Tensor *depth, const Tensor *on_value, const Tensor *off_value, Tensor *output, const OneHotParams &params)
 
const Tensor * indices () const

const Tensor * depth () const

const Tensor * on_value () const

const Tensor * off_value () const

Tensor * output () const
 
void configure () override
 
void execute () const override
 
 OneHot (const Tensor *indices, const Tensor *depth, const Tensor *on_value, const Tensor *off_value, Tensor *output, const OneHotParams &params)
 
const Tensor * indices () const

const Tensor * depth () const

const Tensor * on_value () const

const Tensor * off_value () const

Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< OneHotParams >
const OneHotParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< OneHotParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const OneHotParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< OneHotParams >
const OneHotParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

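Kernel implementing the OneHot operation.

configure() checks the following preconditions:
- `indices` and `depth` must be S32.
- `depth` must hold exactly one element.
- `on_value`, `off_value` and `output` must share one element type.
- `axis` must lie in [-1, rank(indices)].

The output shape is the `indices` shape with a new dimension of size `depth` inserted at `axis`. An `axis` of -1 appends the new dimension last.

execute() supports FLOAT32, U8 and S16 outputs and throws std::runtime_error for any other type.

Definition at line 28 of file OneHot.h.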

Constructor & Destructor Documentation

◆ OneHot() [1/2]

luci_interpreter::kernels::OneHot::OneHot ( const Tensor * indices,
const Tensor * depth,
const Tensor * on_value,
const Tensor * off_value,
Tensor * output,
const OneHotParams & params 
)

Definition at line 69 of file OneHot.cpp.

71 : KernelWithParams<OneHotParams>({indices, depth, on_value, off_value}, {output}, params)
72{
73 // Do nothing
74}
const OneHotParams & params() const
Definition Kernel.h:67
const Tensor * on_value() const
Definition OneHot.h:36
const Tensor * off_value() const
Definition OneHot.h:37
const Tensor * depth() const
Definition OneHot.h:35
const Tensor * indices() const
Definition OneHot.h:34

References depth(), indices(), off_value(), and on_value().

◆ OneHot() [2/2]

luci_interpreter::kernels::OneHot::OneHot ( const Tensor * indices,
const Tensor * depth,
const Tensor * on_value,
const Tensor * off_value,
Tensor * output,
const OneHotParams & params 
)

Member Function Documentation

◆ configure() [1/2]

void luci_interpreter::kernels::OneHot::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 76 of file OneHot.cpp.

77{
78 // check types
79 LUCI_INTERPRETER_CHECK(indices()->element_type() == DataType::S32);
80 LUCI_INTERPRETER_CHECK(depth()->element_type() == DataType::S32);
81 LUCI_INTERPRETER_CHECK(on_value()->element_type() == off_value()->element_type());
82 LUCI_INTERPRETER_CHECK(output()->element_type() == on_value()->element_type());
83
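84 // check shape dependent parameters
85 LUCI_INTERPRETER_CHECK(on_value()->shape().num_elements() == 1);
86 LUCI_INTERPRETER_CHECK(off_value()->shape().num_elements() == 1);
87 LUCI_INTERPRETER_CHECK(depth()->shape().num_elements() == 1);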
88 LUCI_INTERPRETER_CHECK(params().axis >= -1 && params().axis <= indices()->shape().num_dims());
89
90 // define parameters that affect the output shape
91 auto const depth_value = getTensorData<int32_t>(depth())[0];
92 auto const &input_shape = indices()->shape();
93 auto const input_dims = input_shape.num_dims();
94 auto const axis = params().axis == -1 ? input_dims : params().axis;
95
96 // define output shape
97 Shape output_shape(input_shape.num_dims() + 1);
98 {
99 for (int32_t d = 0; d < axis; ++d)
100 output_shape.dim(d) = input_shape.dim(d);
101
102 output_shape.dim(axis) = depth_value;
103
104 for (int32_t d = axis + 1; d < output_shape.num_dims(); ++d)
105 output_shape.dim(d) = input_shape.dim(d - 1);
106 }
107
108 // reshape output
109 output()->resize(output_shape);
110}
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
uint32_t num_elements(const Shape &shape)
The number of elements of a feature map of a given shape.
Definition Shape.h:59
Definition Shape.h:28

References luci_interpreter::OneHotParams::axis, depth(), indices(), LUCI_INTERPRETER_CHECK, luci_interpreter::Shape::num_dims(), off_value(), on_value(), output(), output_shape, luci_interpreter::KernelWithParams< OneHotParams >::params(), luci_interpreter::Tensor::resize(), and luci_interpreter::Tensor::shape().

◆ configure() [2/2]

void luci_interpreter::kernels::OneHot::configure ( )
override virtual

◆ depth() [1/2]

const Tensor * luci_interpreter::kernels::OneHot::depth ( ) const
inline

Definition at line 35 of file OneHot.h.

35{ return _inputs[1]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and OneHot().

◆ depth() [2/2]

const Tensor * luci_interpreter::kernels::OneHot::depth ( ) const
inline

Definition at line 35 of file OneHot.h.

35{ return _inputs[1]; }

References luci_interpreter::Kernel::_inputs.

◆ execute() [1/2]

void luci_interpreter::kernels::OneHot::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 112 of file OneHot.cpp.

113{
114 auto const depth_value = getTensorData<int32_t>(depth())[0];
115 auto const axis = params().axis;
116
117 switch (output()->element_type())
118 {
119 case loco::DataType::FLOAT32:
120 OneHotComputeImpl<float>(indices(), on_value(), off_value(), depth_value, axis, output());
121 break;
122 case loco::DataType::U8:
123 OneHotComputeImpl<uint8_t>(indices(), on_value(), off_value(), depth_value, axis, output());
124 break;
125 case loco::DataType::S16:
126 OneHotComputeImpl<int16_t>(indices(), on_value(), off_value(), depth_value, axis, output());
127 break;
128 default:
129 // TODO Support other data types
130 throw std::runtime_error("Not supported, yet!");
131 break;
132 }
133}

References luci_interpreter::OneHotParams::axis, depth(), indices(), off_value(), on_value(), output(), and luci_interpreter::KernelWithParams< OneHotParams >::params().

◆ execute() [2/2]

void luci_interpreter::kernels::OneHot::execute ( ) const
override virtual

◆ indices() [1/2]

const Tensor * luci_interpreter::kernels::OneHot::indices ( ) const
inline

Definition at line 34 of file OneHot.h.

34{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and OneHot().

◆ indices() [2/2]

const Tensor * luci_interpreter::kernels::OneHot::indices ( ) const
inline

Definition at line 34 of file OneHot.h.

34{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

◆ off_value() [1/2]

const Tensor * luci_interpreter::kernels::OneHot::off_value ( ) const
inline

Definition at line 37 of file OneHot.h.

37{ return _inputs[3]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and OneHot().

◆ off_value() [2/2]

const Tensor * luci_interpreter::kernels::OneHot::off_value ( ) const
inline

Definition at line 37 of file OneHot.h.

37{ return _inputs[3]; }

References luci_interpreter::Kernel::_inputs.

◆ on_value() [1/2]

const Tensor * luci_interpreter::kernels::OneHot::on_value ( ) const
inline

Definition at line 36 of file OneHot.h.

36{ return _inputs[2]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and OneHot().

◆ on_value() [2/2]

const Tensor * luci_interpreter::kernels::OneHot::on_value ( ) const
inline

Definition at line 36 of file OneHot.h.

36{ return _inputs[2]; }

References luci_interpreter::Kernel::_inputs.

◆ output() [1/2]

Tensor * luci_interpreter::kernels::OneHot::output ( ) const
inline

Definition at line 39 of file OneHot.h.

39{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure(), and execute().

◆ output() [2/2]

Tensor * luci_interpreter::kernels::OneHot::output ( ) const
inline

Definition at line 39 of file OneHot.h.

39{ return _outputs[0]; }

References luci_interpreter::Kernel::_outputs.


The documentation for this class was generated from the following files: