ONE - On-device Neural Engine
Loading...
Searching...
No Matches
mir_interpreter::ConcatImpl< T > Struct Template Reference

Static Public Member Functions

static void run (const std::vector< std::reference_wrapper< const mir::TensorVariant > > &inputs, int axis, mir::TensorVariant &output)
 

Detailed Description

template<typename T>
struct mir_interpreter::ConcatImpl< T >

Definition at line 27 of file Concat.cpp.

Member Function Documentation

◆ run()

// Doxygen: static member function of mir_interpreter::ConcatImpl<T>;
// definition at line 34 of file Concat.cpp.
//
// Concatenates the tensors in `inputs` along dimension `axis` and writes the
// result into `output`, which must already be allocated with the final shape.
//
// The copy works on raw bytes. For each combination of indices in the
// dimensions before `axis` ("outer"), every input contributes one contiguous
// run of dim(axis) * (product of dims after axis) elements, written in input
// order. NOTE(review): this assumes each tensor's data is contiguous and
// row-major starting at atOffset(0); confirm TensorVariant guarantees that.
//
// @param inputs  tensors to concatenate. Each must have the output's rank and
//                match the output in every dimension except `axis`
//                (checked with assert only).
// @param axis    concatenation dimension, expected in [0, rank). Negative axes
//                are not normalized here.
// @param output  destination tensor. Its dim(axis) must equal the sum of the
//                inputs' dim(axis).
template <typename T>
void mir_interpreter::ConcatImpl<T>::run(
  const std::vector<std::reference_wrapper<const mir::TensorVariant>> &inputs, int axis,
  mir::TensorVariant &output)
{
  const auto &output_shape = output.getShape();
  const int32_t rank = output_shape.rank();
  assert(axis >= 0 && axis < rank);

  // Validate input shapes (debug builds only) and accumulate the concatenated extent.
  int64_t concat_size = 0;
  for (const auto &input_ref : inputs)
  {
    // Bind by reference: no need to copy the shape object.
    const auto &input_shape = input_ref.get().getShape();
    assert(input_shape.rank() == rank);
    for (int32_t j = 0; j < rank; ++j)
    {
      if (j != axis)
        assert(input_shape.dim(j) == output_shape.dim(j));
    }
    concat_size += input_shape.dim(axis);
  }
  assert(concat_size == output_shape.dim(axis));
  (void)concat_size; // only read by the assert above

  // With no inputs there is nothing to copy. Returning here also avoids
  // reading inputs[0] below, which would be undefined behavior.
  if (inputs.empty())
    return;

  // 64-bit arithmetic: element counts of large tensors can exceed INT32_MAX.
  // Product of the dimensions before the concatenation axis.
  int64_t outer_size = 1;
  for (int32_t i = 0; i < axis; ++i)
    outer_size *= output_shape.dim(i);

  // Product of the dimensions after the concatenation axis.
  int64_t inner_size = 1;
  for (int32_t i = axis + 1; i < rank; ++i)
    inner_size *= output_shape.dim(i);

  const size_t elem_size = inputs[0].get().getElementSize();

  // Per input: the byte length of one contiguous run, and a read cursor
  // that advances through that input's data.
  std::vector<size_t> copy_bytes;
  std::vector<const char *> input_ptrs;
  copy_bytes.reserve(inputs.size());
  input_ptrs.reserve(inputs.size());
  for (const auto &input_ref : inputs)
  {
    const auto &input = input_ref.get();
    copy_bytes.push_back(static_cast<size_t>(input.getShape().dim(axis) * inner_size) *
                         elem_size);
    input_ptrs.push_back(input.atOffset(0));
  }

  // Interleave the runs: for each outer index, append one run from each input in order.
  char *output_ptr = output.atOffset(0);
  for (int64_t outer = 0; outer < outer_size; ++outer)
  {
    for (size_t j = 0; j < inputs.size(); ++j)
    {
      std::memcpy(output_ptr, input_ptrs[j], copy_bytes[j]);
      output_ptr += copy_bytes[j];
      input_ptrs[j] += copy_bytes[j];
    }
  }
}
const luci_interpreter::RuntimeShape output_shape
KnobTrait< K >::ValueType get(void)

References output_shape.

Referenced by package.infer.session::inference().


The documentation for this struct was generated from the following file: Concat.cpp