ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::Select Class Reference

#include <Select.h>

Collaboration diagram for luci_interpreter::kernels::Select:

Public Member Functions

 Select (const Tensor *cond, const Tensor *t, const Tensor *e, Tensor *output)
 
const Tensor *condition () const
 
const Tensor *t () const
 
const Tensor *e () const
 
Tensor *output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file Select.h.

Constructor & Destructor Documentation

◆ Select()

luci_interpreter::kernels::Select::Select ( const Tensor * cond,
const Tensor * t,
const Tensor * e,
Tensor * output
)

Definition at line 33 of file Select.cpp.

34 : Kernel({condition, t, e}, {output})
35{
36 _has_low_rank_input_condition = false;
37}
Kernel(std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
Definition Kernel.h:31
const Tensor * e() const
Definition Select.h:35
const Tensor * t() const
Definition Select.h:34
const Tensor * condition() const
Definition Select.h:33

References condition(), e(), and t().

Member Function Documentation

◆ condition()

const Tensor * luci_interpreter::kernels::Select::condition ( ) const
inline

Definition at line 33 of file Select.h.

33{ return _inputs[0]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and Select().

◆ configure()

void luci_interpreter::kernels::Select::configure ( )
overridevirtual

Implements luci_interpreter::Kernel.

Definition at line 39 of file Select.cpp.

40{
41 LUCI_INTERPRETER_CHECK(condition()->element_type() == DataType::BOOL);
42 LUCI_INTERPRETER_CHECK(t()->element_type() == e()->element_type());
43 LUCI_INTERPRETER_CHECK(t()->element_type() == output()->element_type());
44
45 auto cond_shape = condition()->shape();
46 auto cond_num_dims = cond_shape.num_dims();
47 auto t_shape = t()->shape();
48
49 bool is_input_condition_scalar = cond_num_dims == 0;
50 bool has_rank_one_input_condition = cond_num_dims == 1 && cond_shape.dim(0) == t_shape.dim(0);
51
52 _has_low_rank_input_condition = is_input_condition_scalar || has_rank_one_input_condition;
53
54 output()->resize(calculateShapeForBroadcast(t()->shape(), e()->shape()));
55}
int32_t dim(int i) const
Definition Tensor.h:41
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
Shape calculateShapeForBroadcast(const Shape &input1_shape, const Shape &input2_shape)
Definition Utils.cpp:204

References luci_interpreter::kernels::calculateShapeForBroadcast(), condition(), luci_interpreter::Shape::dim(), e(), LUCI_INTERPRETER_CHECK, luci_interpreter::Shape::num_dims(), output(), luci_interpreter::Tensor::resize(), luci_interpreter::Tensor::shape(), and t().

◆ e()

const Tensor * luci_interpreter::kernels::Select::e ( ) const
inline

Definition at line 35 of file Select.h.

35{ return _inputs[2]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and Select().

◆ execute()

void luci_interpreter::kernels::Select::execute ( ) const
overridevirtual

Implements luci_interpreter::Kernel.

Definition at line 57 of file Select.cpp.

58{
59 switch (t()->element_type())
60 {
61 case DataType::FLOAT32:
62 evalFloat();
63 break;
64 default:
65 throw std::runtime_error("luci-intp Select unsupported type.");
66 }
67}

References t().

◆ output()

Tensor * luci_interpreter::kernels::Select::output ( ) const
inline

Definition at line 36 of file Select.h.

36{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure().

◆ t()

const Tensor * luci_interpreter::kernels::Select::t ( ) const
inline

Definition at line 34 of file Select.h.

34{ return _inputs[1]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and Select().


The documentation for this class was generated from the following files: