ONE - On-device Neural Engine
Loading...
Searching...
No Matches
BinaryOpCommon.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef LUCI_INTERPRETER_KERNELS_BINARYOPUTILS_H
19#define LUCI_INTERPRETER_KERNELS_BINARYOPUTILS_H
20
21#include "tensorflow/lite/kernels/internal/common.h"
22#include "tensorflow/lite/kernels/internal/types.h"
23
24namespace luci_interpreter
25{
26namespace kernels
27{
28
29// Derived from tensorflow/lite/kernels/internal/reference/maximum_minimum.h (v2.3.0).
30template <typename T, typename Op, int N = 5>
31void BinaryOpBroadcastSlow(const tflite::RuntimeShape &unextended_input1_shape,
32 const T *input1_data,
33 const tflite::RuntimeShape &unextended_input2_shape,
34 const T *input2_data,
35 const tflite::RuntimeShape &unextended_output_shape, T *output_data,
36 Op op)
37{
38 if (unextended_input1_shape == unextended_input2_shape)
39 {
40 const int flat_size = tflite::MatchingElementsSize(
41 unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
42 for (int i = 0; i < flat_size; ++i)
43 {
44 output_data[i] = op(input1_data[i], input2_data[i]);
45 }
46 }
47 else
48 {
49 assert(unextended_input1_shape.DimensionsCount() <= N);
50 assert(unextended_input2_shape.DimensionsCount() <= N);
51 assert(unextended_output_shape.DimensionsCount() <= N);
52
53 tflite::NdArrayDesc<N> desc1{};
54 tflite::NdArrayDesc<N> desc2{};
55 tflite::NdArrayDesc<N> output_desc{};
56 tflite::NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape,
57 &desc1, &desc2);
58 tflite::CopyDimsToDesc(tflite::RuntimeShape::ExtendedShape(N, unextended_output_shape),
59 &output_desc);
60
61 auto fn = [&](int indexes[N]) {
62 output_data[SubscriptToIndex(output_desc, indexes)] =
63 op(input1_data[SubscriptToIndex(desc1, indexes)],
64 input2_data[SubscriptToIndex(desc2, indexes)]);
65 };
66 tflite::NDOpsHelper<N>(output_desc, fn);
67 }
68}
69
70} // namespace kernels
71} // namespace luci_interpreter
72
73#endif // LUCI_INTERPRETER_KERNELS_BINARYOPUTILS_H
int SubscriptToIndex(const NdArrayDesc< 4 > &desc, int i0, int i1, int i2, int i3)
Definition NDArray.h:54
NdArrayDesc< 4 > desc1
NdArrayDesc< 4 > desc2
void BinaryOpBroadcastSlow(const tflite::RuntimeShape &unextended_input1_shape, const T *input1_data, const tflite::RuntimeShape &unextended_input2_shape, const T *input2_data, const tflite::RuntimeShape &unextended_output_shape, T *output_data, Op op)