ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALBroadcastTo.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef LUCI_INTERPRETER_PAL_BROADCAST_TO_COMMON_H
19#define LUCI_INTERPRETER_PAL_BROADCAST_TO_COMMON_H
20
21#include "PALUtils.h"
23
24#include <cmath>
25
27{
28
29template <int N>
30void BroadcastImpl(const NdArrayDesc<N> &input_desc, const uint8_t *input_data,
31 const NdArrayDesc<N> &output_desc, uint8_t *output_data, int indexes[N], int dim,
32 const int last_broadcasting_dim, const uint32_t type_size)
33{
34 // Copy data from input to output.
35 if (dim == last_broadcasting_dim)
36 {
37 int copy_size = output_desc.strides[dim] * type_size;
38 const uint8_t *data_src = input_data + subscriptToIndex(input_desc, indexes) * type_size;
39 uint8_t *data_dst = output_data + subscriptToIndex(output_desc, indexes) * type_size;
40 for (int i = 0; i < output_desc.extents[dim]; ++i, data_dst += copy_size)
41 {
42 memcpy(data_dst, data_src, copy_size);
43 }
44 return;
45 }
46
47 // Recursive call to find the next broadcasting.
48 for (indexes[dim] = 0; indexes[dim] < input_desc.extents[dim]; ++indexes[dim])
49 {
50 BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, dim + 1,
51 last_broadcasting_dim, type_size);
52 }
53
54 // Duplicate data in output tensor.
55 indexes[dim] = 0;
56 if (input_desc.extents[dim] != output_desc.extents[dim])
57 {
58 int copy_size = output_desc.strides[dim] * type_size;
59 uint8_t *data_src = output_data + subscriptToIndex(output_desc, indexes) * type_size;
60 uint8_t *data_dst = data_src + copy_size;
61 for (int i = 1; i < output_desc.extents[dim]; ++i, data_dst += copy_size)
62 {
63 memcpy(data_dst, data_src, copy_size);
64 }
65 }
66}
67
68template <int N>
69inline void BroadcastTo(const luci_interpreter::RuntimeShape &unextended_input_shape,
70 const uint8_t *input_data,
71 const luci_interpreter::RuntimeShape &unextended_output_shape,
72 uint8_t *output_data, luci_interpreter::DataType data_type)
73{
74 NdArrayDesc<N> input_desc;
75 NdArrayDesc<N> output_desc;
77 &input_desc);
79 &output_desc);
80
81 // Get the last dimension has broadcasting. At this dimension, the data is
82 // copied from input tensor to output tensor.
83 int last_broadcast_dim = -1;
84 for (int i = N - 1; i >= 0; --i)
85 {
86 if (input_desc.extents[i] != output_desc.extents[i])
87 {
88 last_broadcast_dim = i;
89 break;
90 }
91 }
92
93 // If non-broadcasting, just copy data from input to output tensor.
94 if (last_broadcast_dim == -1)
95 {
96 memcpy(output_data, input_data, unextended_input_shape.flatSize() * sizeof(data_type));
97 return;
98 }
99
100 // Broadcasting using memcpy.
101 int indexes[N] = {0};
102 BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, 0, last_broadcast_dim,
103 luci_interpreter::size(data_type));
104}
105
106} // namespace luci_interpreter_pal
107
108#endif // LUCI_INTERPRETER_PAL_BROADCAST_TO_COMMON_H
static RuntimeShape extendedShape(int new_shape_size, const RuntimeShape &shape)
Definition Tensor.h:95
void copyDimsToDesc(const luci_interpreter::RuntimeShape &input_shape, NdArrayDesc< N > *desc_out)
int subscriptToIndex(const NdArrayDesc< 4 > &desc, int i0, int i1, int i2, int i3)
void BroadcastImpl(const NdArrayDesc< N > &input_desc, const uint8_t *input_data, const NdArrayDesc< N > &output_desc, uint8_t *output_data, int indexes[N], int dim, const int last_broadcasting_dim, const uint32_t type_size)
DataType
"scalar" value type
Definition DataType.h:32
uint32_t size(loco::DataType data_type)
Returns the size of the data type.