ONE - On-device Neural Engine
luci_interpreter::kernels::BroadcastTo Class Reference

#include <BroadcastTo.h>


Public Member Functions

 BroadcastTo (const Tensor *input, const Tensor *shape, Tensor *output)
const Tensor * input () const
const Tensor * shape () const
Tensor * output () const
void configure () override
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel () = default
const std::vector<const Tensor *> & getInputTensors () const
const std::vector<Tensor *> & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector<const Tensor *> inputs, std::vector<Tensor *> outputs)

- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector<const Tensor *> _inputs
const std::vector<Tensor *> _outputs
 

Detailed Description

Definition at line 28 of file BroadcastTo.h.

Constructor & Destructor Documentation

◆ BroadcastTo()

luci_interpreter::kernels::BroadcastTo::BroadcastTo ( const Tensor * input,
const Tensor * shape,
Tensor * output
)

Definition at line 73 of file BroadcastTo.cpp.

  : Kernel({input, shape}, {output})
{
}

References input(), and shape().

Member Function Documentation

◆ configure()

void luci_interpreter::kernels::BroadcastTo::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 78 of file BroadcastTo.cpp.

{
  LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());

  Shape output_shape = extractShapeFromTensor(shape());

  int input_rank = input()->shape().num_dims();
  int output_rank = output_shape.num_dims();

  // Ensures output rank is not less than input rank
  LUCI_INTERPRETER_CHECK(input_rank <= output_rank);

  // Check if output shape is broadcastable from input shape
  // from https://www.tensorflow.org/api_docs/python/tf/broadcast_to
  // if a tensor has fewer axes than necessary its shape is padded on the left with ones.
  int extending_rank = output_rank - input_rank;
  for (int idx = 0; idx < input_rank; ++idx)
  {
    LUCI_INTERPRETER_CHECK(input()->shape().dim(idx) == 1 ||
                           input()->shape().dim(idx) == output_shape.dim(extending_rank + idx));
  }

  output()->resize(output_shape);
}

References input(), LUCI_INTERPRETER_CHECK, luci_interpreter::Shape::num_dims(), output(), output_shape, luci_interpreter::Tensor::resize(), luci_interpreter::Tensor::shape(), and shape().
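The compatibility rule enforced by configure() can be exercised outside the interpreter. The following is a minimal standalone sketch, assuming plain std::vector<int> shapes in place of luci_interpreter::Shape; the helper name isBroadcastableTo is hypothetical and not part of luci-interpreter.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of luci-interpreter): repeats the checks that
// BroadcastTo::configure() performs, with std::vector<int> standing in for Shape.
bool isBroadcastableTo(const std::vector<int> &input_shape, const std::vector<int> &output_shape)
{
  const std::size_t input_rank = input_shape.size();
  const std::size_t output_rank = output_shape.size();

  // Output rank must not be less than input rank.
  if (input_rank > output_rank)
    return false;

  // The input shape is treated as left-padded with ones, so input axis idx
  // is compared against output axis (extending_rank + idx).
  const std::size_t extending_rank = output_rank - input_rank;
  for (std::size_t idx = 0; idx < input_rank; ++idx)
  {
    const int in_dim = input_shape[idx];
    // A size-1 input axis can be stretched; any other size must match exactly.
    if (in_dim != 1 && in_dim != output_shape[extending_rank + idx])
      return false;
  }
  return true;
}
```

For example, isBroadcastableTo({1, 3}, {4, 2, 3}) returns true (the missing leading axis is padded and the size-1 axis stretches), while isBroadcastableTo({2}, {2, 3}) returns false because the input's only axis (2) is aligned with the trailing output axis (3).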

◆ execute()

void luci_interpreter::kernels::BroadcastTo::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 103 of file BroadcastTo.cpp.

{
  switch (input()->element_type())
  {
    case DataType::FLOAT32:
      evalFloat();
      break;
    case DataType::BOOL:
      evalBool();
      break;
    default:
      throw std::runtime_error("luci-intp BroadcastTo Unsupported type.");
  }
}

References input().
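execute() only dispatches on the element type; the bodies of evalFloat() and evalBool() are not shown on this page. One way such a typed evaluation could work is to map every output coordinate back to an input element, reading index 0 along any axis whose (left-padded) input size is 1. The template below is a hypothetical, row-major sketch of that technique, not the luci-interpreter implementation; the name broadcastTo and the std::vector-based signature are assumptions, and the shapes are assumed to have already passed the configure() checks.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of a typed BroadcastTo evaluation (row-major layout).
// Assumes input_shape is broadcastable to output_shape (see configure()).
template <typename T>
std::vector<T> broadcastTo(const std::vector<T> &input, const std::vector<int> &input_shape,
                           const std::vector<int> &output_shape)
{
  assert(input_shape.size() <= output_shape.size());
  const std::size_t out_rank = output_shape.size();
  const std::size_t extending_rank = out_rank - input_shape.size();

  // Left-pad the input shape with ones so both shapes have the same rank.
  std::vector<int> padded(out_rank, 1);
  for (std::size_t i = 0; i < input_shape.size(); ++i)
    padded[extending_rank + i] = input_shape[i];

  std::size_t out_size = 1;
  for (int d : output_shape)
    out_size *= static_cast<std::size_t>(d);

  std::vector<T> output(out_size);
  std::vector<int> coord(out_rank, 0); // current output coordinate
  for (std::size_t flat = 0; flat < out_size; ++flat)
  {
    // Map the output coordinate to an input flat index; size-1 axes always read index 0.
    std::size_t in_index = 0;
    for (std::size_t axis = 0; axis < out_rank; ++axis)
    {
      const int in_coord = (padded[axis] == 1) ? 0 : coord[axis];
      in_index = in_index * static_cast<std::size_t>(padded[axis]) +
                 static_cast<std::size_t>(in_coord);
    }
    output[flat] = input[in_index];

    // Advance the output coordinate like an odometer (last axis fastest).
    for (std::size_t axis = out_rank; axis-- > 0;)
    {
      if (++coord[axis] < output_shape[axis])
        break;
      coord[axis] = 0;
    }
  }
  return output;
}
```

Because the template is type-generic, one body covers both the FLOAT32 and BOOL cases that execute() handles. For instance, broadcasting {1, 2, 3} with shape {3} to shape {2, 3} yields {1, 2, 3, 1, 2, 3}, and broadcasting {1, 2} with shape {2, 1} to {2, 3} yields {1, 1, 1, 2, 2, 2}.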

◆ input()

const Tensor * luci_interpreter::kernels::BroadcastTo::input ( ) const
inline

Definition at line 33 of file BroadcastTo.h.

{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by BroadcastTo(), configure(), and execute().

◆ output()

Tensor * luci_interpreter::kernels::BroadcastTo::output ( ) const
inline

Definition at line 35 of file BroadcastTo.h.

{ return _outputs[0]; }

References luci_interpreter::Kernel::_outputs.

Referenced by configure().
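The accessors are thin wrappers over the vectors held by Kernel: the constructor forwards {input, shape} as the inputs and {output} as the single output, so input() reads _inputs[0] and output() reads _outputs[0]. The following is a minimal mock of that pattern, assuming simplified stand-in Tensor and Kernel types (the real ones are richer, and the real Kernel also declares configure() and execute()). BroadcastToSketch is a hypothetical name, and shape() returning _inputs[1] is inferred from the constructor's argument order.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-ins for luci_interpreter::Tensor and Kernel, reduced to
// what the accessor pattern needs.
struct Tensor
{
  const char *name;
};

class Kernel
{
public:
  virtual ~Kernel() = default;

protected:
  Kernel(std::vector<const Tensor *> inputs, std::vector<Tensor *> outputs)
    : _inputs(std::move(inputs)), _outputs(std::move(outputs))
  {
  }

  const std::vector<const Tensor *> _inputs;
  const std::vector<Tensor *> _outputs;
};

// Mirrors how BroadcastTo forwards its tensors to Kernel and exposes them
// through index-based accessors.
class BroadcastToSketch : public Kernel
{
public:
  BroadcastToSketch(const Tensor *input, const Tensor *shape, Tensor *output)
    : Kernel({input, shape}, {output})
  {
  }

  const Tensor *input() const { return _inputs[0]; }
  const Tensor *shape() const { return _inputs[1]; } // assumed: second input
  Tensor *output() const { return _outputs[0]; }
};
```

Keeping the tensors in the base class lets generic code reach them through getInputTensors() and getOutputTensors(), while the named accessors keep the kernel body readable.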

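◆ shape()

const Tensor * luci_interpreter::kernels::BroadcastTo::shape ( ) const
inline

{ return _inputs[1]; }

References luci_interpreter::Kernel::_inputs.

Referenced by BroadcastTo(), and configure().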

The documentation for this class was generated from the following files: