ONE - On-device Neural Engine
arm_compute::CLSplitVEx Class Reference

#include <CLSplitVEx.h>


Public Member Functions

CLSplitVEx ()
void configure (const ICLTensor *input, const ICLTensor *size_splits, uint32_t split_dim, const std::vector<ICLTensor *> &outputs, unsigned int num_splits)
void run () override

Detailed Description

Basic function to run CLSplitVKernel

Definition at line 57 of file CLSplitVEx.h.
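To make the SplitV semantics concrete, here is a minimal host-side sketch on a 2-D float tensor. It is not the OpenCL implementation. Tensor2D and split_v are hypothetical names, the sketch assumes size_splits[i] is the extent of output i along split_dim (as in TensorFlow's SplitV), and it stores elements with dimension 0 varying fastest, matching Arm Compute Library's TensorShape convention.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical 2-D tensor: element (x, y) lives at data[y * width + x],
// i.e. dimension 0 (x) varies fastest.
struct Tensor2D
{
  std::size_t width;  // extent of dimension 0
  std::size_t height; // extent of dimension 1
  std::vector<float> data;

  float at(std::size_t x, std::size_t y) const
  {
    assert(x < width && y < height);
    return data[y * width + x];
  }
};

// Hypothetical SplitV: cut `in` along `split_dim` (0 or 1) into pieces whose
// extents along that dimension are given by `size_splits`.
std::vector<Tensor2D> split_v(const Tensor2D &in, const std::vector<std::size_t> &size_splits,
                              unsigned split_dim)
{
  if (split_dim > 1)
    throw std::invalid_argument("split_dim must be 0 or 1 for a 2-D tensor");
  const std::size_t extent = (split_dim == 0) ? in.width : in.height;
  // Assumption: the split sizes must cover the split dimension exactly.
  if (std::accumulate(size_splits.begin(), size_splits.end(), std::size_t{0}) != extent)
    throw std::invalid_argument("size_splits must add up to the extent of split_dim");

  std::vector<Tensor2D> outputs;
  std::size_t offset = 0; // start of the current piece along split_dim
  for (std::size_t len : size_splits)
  {
    // Only the split dimension shrinks; the other keeps the input extent.
    Tensor2D out{(split_dim == 0) ? len : in.width, (split_dim == 1) ? len : in.height, {}};
    out.data.reserve(out.width * out.height);
    for (std::size_t y = 0; y < out.height; ++y)
      for (std::size_t x = 0; x < out.width; ++x)
        out.data.push_back(split_dim == 0 ? in.at(offset + x, y) : in.at(x, offset + y));
    outputs.push_back(std::move(out));
    offset += len;
  }
  return outputs;
}
```

For example, splitting a tensor with width 3 and height 2 along dimension 0 with size_splits = {1, 2} yields a 1-wide and a 2-wide output, each keeping both rows.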

Constructor & Destructor Documentation

◆ CLSplitVEx()

CLSplitVEx::CLSplitVEx ( )

Default constructor

Definition at line 156 of file CLSplitVEx.cpp.

```cpp
CLSplitVEx::CLSplitVEx()
    : _input(nullptr), _size_splits(nullptr), _outputs(), _num_splits(0), _slice_functions()
{
}
```

Member Function Documentation

◆ configure()

void CLSplitVEx::configure (const ICLTensor *input, const ICLTensor *size_splits, uint32_t split_dim, const std::vector<ICLTensor *> &outputs, unsigned int num_splits)

Configures the split function: validates the arguments and the output shapes, then sets up one slice function per split.

Parameters
[in] input: The input tensor to split. Data types supported: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32.
[in] size_splits: A 1-D tensor containing the size of each split (the number of elements along split_dim).
[in] split_dim: Integer value representing the input tensor dimension along which to split.
[out] outputs: A vector containing the output tensors. Data types supported: same as input. The output tensors should match the input tensor dimensions for all shape dimensions apart from the split dimension.
[in] num_splits: Number of splits.
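The shape contract stated for outputs can be checked on plain shape vectors. The sketch below is a hypothetical host-side check: outputs_match_input is not part of the library, and the real checks live in validate_arguments and validate_slices, whose bodies are not shown on this page. It additionally assumes, following TensorFlow's SplitV semantics, that the output extents along split_dim must add up to the input extent.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical check of the documented contract: every output matches the
// input on all dimensions except split_dim.
// Assumption (not stated on this page): the output extents along split_dim
// must sum to the input extent along split_dim.
bool outputs_match_input(const std::vector<std::size_t> &input_shape,
                         const std::vector<std::vector<std::size_t>> &output_shapes,
                         std::size_t split_dim)
{
  if (split_dim >= input_shape.size())
    return false; // split dimension out of range
  std::size_t total = 0; // running sum of extents along split_dim
  for (const auto &shape : output_shapes)
  {
    if (shape.size() != input_shape.size())
      return false; // rank mismatch
    for (std::size_t d = 0; d < shape.size(); ++d)
    {
      if (d != split_dim && shape[d] != input_shape[d])
        return false; // a non-split dimension differs from the input
    }
    total += shape[split_dim];
  }
  return total == input_shape[split_dim];
}
```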

Definition at line 161 of file CLSplitVEx.cpp.

```cpp
void CLSplitVEx::configure(const ICLTensor *input, const ICLTensor *size_splits, uint32_t split_dim,
                           const std::vector<ICLTensor *> &outputs, unsigned int num_splits)
{
  ARM_COMPUTE_ERROR_ON_NULLPTR(input, size_splits);
  ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(size_splits, outputs, num_splits));

  _input = input;
  _size_splits = size_splits;
  _outputs = outputs;
  _num_splits = num_splits;

  // Create tensor slices
  _slice_functions.resize(_num_splits);

  // Extract output tensor info
  std::vector<ITensorInfo *> outputs_info;
  for (auto &&output : _outputs)
  {
    ARM_COMPUTE_ERROR_ON_NULLPTR(output);
    outputs_info.emplace_back(output->info());
  }

  // Validate slices
  ARM_COMPUTE_ERROR_THROW_ON(validate_slices(_input->info(), outputs_info, split_dim));

  // Configure slices
  configure_slices(_input, _outputs, _slice_functions, split_dim);
}
```
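configure_slices receives the input, the outputs and split_dim but not size_splits, which suggests the slice boundaries are derived from the output shapes. Its body and the type of _slice_functions are not shown on this page, so the following is only a sketch under that assumption: each output spans the full input extent on every dimension except split_dim, where consecutive outputs occupy consecutive ranges. SliceBounds and compute_slice_bounds are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Start (inclusive) and end (exclusive) coordinates of one output in the input.
struct SliceBounds
{
  std::vector<std::size_t> start;
  std::vector<std::size_t> end;
};

// Hypothetical helper: derive each output's slice from the (already
// validated) output shapes. Only split_dim gets a non-zero start offset.
std::vector<SliceBounds> compute_slice_bounds(const std::vector<std::size_t> &input_shape,
                                              const std::vector<std::vector<std::size_t>> &output_shapes,
                                              std::size_t split_dim)
{
  assert(split_dim < input_shape.size());
  std::vector<SliceBounds> bounds;
  std::size_t offset = 0; // where the next slice starts along split_dim
  for (const auto &shape : output_shapes)
  {
    // Default: the slice covers the whole input on every dimension.
    SliceBounds b{std::vector<std::size_t>(input_shape.size(), 0), input_shape};
    b.start[split_dim] = offset;
    b.end[split_dim] = offset + shape[split_dim];
    offset = b.end[split_dim];
    bounds.push_back(b);
  }
  return bounds;
}
```

For an input of shape {4, 6} split along dimension 1 into outputs of shape {4, 2} and {4, 4}, the first slice covers [0, 2) and the second [2, 6) along that dimension.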

◆ run()

void CLSplitVEx::run ( ) override

Definition at line 190 of file CLSplitVEx.cpp.

```cpp
void CLSplitVEx::run()
{
  // execute the slices
  for (unsigned i = 0; i < _outputs.size(); ++i)
  {
    _slice_functions[i].run();
  }
}
```

Referenced by package.infer.session::inference().
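run() only replays the slice functions prepared by configure(), so all shape work happens once at configuration time. Its loop is bounded by _outputs.size() while it indexes _slice_functions, which configure() resized to _num_splits; this is safe only if the two counts agree, which validate_arguments presumably enforces (its body is not shown here). The host-side analogue below is a sketch, not the library's code: SplitRunner is a hypothetical name, std::function stands in for the per-slice functions, and the loop is bounded by the vector it indexes.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical analogue of the configure/run split: configure() stores one
// prepared job per output, run() executes them in order.
class SplitRunner
{
public:
  // Store the pre-built slice jobs (the analogue of _slice_functions).
  void configure(std::vector<std::function<void()>> slice_jobs)
  {
    _slice_jobs = std::move(slice_jobs);
  }

  // Execute every configured job; the bound comes from the vector being
  // indexed, so the loop cannot run past its end.
  void run()
  {
    for (std::size_t i = 0; i < _slice_jobs.size(); ++i)
    {
      _slice_jobs[i]();
    }
  }

private:
  std::vector<std::function<void()>> _slice_jobs;
};
```

Because configuration is separated from execution, run() can be called repeatedly (for example once per inference) without re-validating shapes.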


The documentation for this class was generated from the following files: