ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::BatchToSpaceND Class Reference

#include <BatchToSpaceND.h>

Collaboration diagram for luci_interpreter::kernels::BatchToSpaceND:

Public Member Functions

 BatchToSpaceND (const Tensor *input, const Tensor *block_shape, const Tensor *crops, Tensor *output)
 
const Tensor * input () const
 
const Tensor * block_shape () const
 
const Tensor * crops () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 27 of file BatchToSpaceND.h.

Constructor & Destructor Documentation

◆ BatchToSpaceND()

luci_interpreter::kernels::BatchToSpaceND::BatchToSpaceND ( const Tensor * input,
const Tensor * block_shape,
const Tensor * crops,
Tensor * output 
)

Definition at line 37 of file BatchToSpaceND.cpp.

40{
41}
Kernel(std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
Definition Kernel.h:31

References block_shape(), crops(), and input().

Member Function Documentation

◆ block_shape()

const Tensor * luci_interpreter::kernels::BatchToSpaceND::block_shape ( ) const
inline

Definition at line 34 of file BatchToSpaceND.h.

34{ return _inputs[1]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by BatchToSpaceND(), configure(), and execute().

◆ configure()

void luci_interpreter::kernels::BatchToSpaceND::configure ( )
overridevirtual

Implements luci_interpreter::Kernel.

Definition at line 43 of file BatchToSpaceND.cpp.

44{
45
46 const auto *block_shape_data = block_shape()->data<int32_t>();
47 const auto *crops_data = crops()->data<int32_t>();
48 LUCI_INTERPRETER_CHECK(input()->shape().num_dims() >= kInputMinDimensionNum);
49 LUCI_INTERPRETER_CHECK(input()->shape().num_dims() <= kInputMaxDimensionNum);
50 LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
51
52 int spatial_dims_num = input()->shape().num_dims() - 2;
53
54 LUCI_INTERPRETER_CHECK(block_shape()->shape().num_dims() == 1);
55 LUCI_INTERPRETER_CHECK(block_shape()->shape().dim(0) == spatial_dims_num);
56
57 LUCI_INTERPRETER_CHECK(crops()->shape().num_dims() == 2);
58 LUCI_INTERPRETER_CHECK(crops()->shape().dim(0) == spatial_dims_num);
59 LUCI_INTERPRETER_CHECK(crops()->shape().dim(1) == 2);
60 for (int i = 0; i < spatial_dims_num * 2; ++i)
61 {
62 LUCI_INTERPRETER_CHECK(crops_data[i] >= 0);
63 }
64
65 Shape output_shape = Shape(input()->shape().num_dims());
66 int output_batch_size = input()->shape().dim(0);
67 for (int i = 0; i < spatial_dims_num; ++i)
68 {
69 LUCI_INTERPRETER_CHECK(output_batch_size % block_shape_data[i] == 0);
70 output_batch_size = output_batch_size / block_shape_data[i];
71 output_shape.dim(i + 1) =
72 input()->shape().dim(i + 1) * block_shape_data[i] - crops_data[i * 2] - crops_data[i * 2 + 1];
73 }
74
75 output_shape.dim(0) = output_batch_size;
76 output_shape.dim(input()->shape().num_dims() - 1) =
77 input()->shape().dim(input()->shape().num_dims() - 1);
78 output()->resize(output_shape);
79}
int32_t dim(int i) const
Definition Tensor.h:41
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
const T * data() const
Definition Tensor.h:127
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
Definition Shape.h:28

References block_shape(), crops(), luci_interpreter::Tensor::data(), luci_interpreter::Shape::dim(), input(), LUCI_INTERPRETER_CHECK, luci_interpreter::Shape::num_dims(), output(), output_shape, luci_interpreter::Tensor::resize(), and luci_interpreter::Tensor::shape().

◆ crops()

const Tensor * luci_interpreter::kernels::BatchToSpaceND::crops ( ) const
inline

Definition at line 35 of file BatchToSpaceND.h.

35{ return _inputs[2]; }

References luci_interpreter::Kernel::_inputs.

Referenced by BatchToSpaceND(), configure(), and execute().

◆ execute()

void luci_interpreter::kernels::BatchToSpaceND::execute ( ) const
overridevirtual

Implements luci_interpreter::Kernel.

Definition at line 81 of file BatchToSpaceND.cpp.

82{
83 switch (input()->element_type())
84 {
85 case DataType::FLOAT32:
86 luci_interpreter_pal::BatchToSpaceND(
87 getTensorShape(input()), getTensorData<float>(input()), getTensorShape(block_shape()),
88 getTensorData<int32_t>(block_shape()), getTensorShape(crops()),
89 getTensorData<int32_t>(crops()), getTensorShape(output()), getTensorData<float>(output()));
90 break;
91 case DataType::U8:
92 luci_interpreter_pal::BatchToSpaceND(
93 getTensorShape(input()), getTensorData<uint8_t>(input()), getTensorShape(block_shape()),
94 getTensorData<int32_t>(block_shape()), getTensorShape(crops()),
95 getTensorData<int32_t>(crops()), getTensorShape(output()),
96 getTensorData<uint8_t>(output()));
97 break;
98 default:
99 throw std::runtime_error("luci-intp BatchToSpaceND Unsupported type.");
100 }
101}
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194

References block_shape(), crops(), luci_interpreter::kernels::getTensorShape(), input(), and output().

◆ input()

const Tensor * luci_interpreter::kernels::BatchToSpaceND::input ( ) const
inline

Definition at line 33 of file BatchToSpaceND.h.

33{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by BatchToSpaceND(), configure(), and execute().

◆ output()

Tensor * luci_interpreter::kernels::BatchToSpaceND::output ( ) const
inline

Definition at line 36 of file BatchToSpaceND.h.

36{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure(), and execute().


The documentation for this class was generated from the following files: