36{
38 const size_t inputs_count =
inputs.size();
40 int64_t concat_size = 0;
41 for (size_t i = 0; i < inputs_count; i++)
42 {
43 const auto &input_shape =
inputs[i].get().getShape();
44 assert(input_shape.rank() == concat_dims);
45 for (int32_t j = 0; j < concat_dims; j++)
46 {
47 if (j != axis)
48 {
50 }
51 }
52 concat_size += input_shape.dim(axis);
53 }
55
56 int32_t outer_size = 1;
57 for (int32_t i = 0; i < axis; i++)
59
60 int32_t base_inner_size = 1;
61 for (int32_t i = axis + 1; i < concat_dims; i++)
63
64 std::vector<int32_t> copy_sizes;
65 std::vector<char *> input_ptrs;
66 for (size_t i = 0; i < inputs_count; i++)
67 {
68 const auto input_shape =
inputs[i].get().getShape();
69 copy_sizes.push_back(input_shape.dim(axis) * base_inner_size);
70 input_ptrs.push_back(inputs[i].
get().atOffset(0));
71 }
72
73 char *output_ptr =
output.atOffset(0);
74 const size_t elem_size =
inputs[0].get().getElementSize();
75 for (int32_t i = 0; i < outer_size; i++)
76 {
77 for (size_t j = 0; j < inputs_count; j++)
78 {
79 std::memcpy(output_ptr, input_ptrs[j], copy_sizes[j] * elem_size);
80 output_ptr += copy_sizes[j] * elem_size;
81 input_ptrs[j] += copy_sizes[j] * elem_size;
82 }
83 }
84}
const luci_interpreter::RuntimeShape output_shape
KnobTrait< K >::ValueType get(void)