ONE - On-device Neural Engine
Loading...
Searching...
No Matches
BroadcastTo.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "kernels/BroadcastTo.h"
18#include "kernels/Utils.h"
19
20#include "PALBroadcastTo.h"
21
22#include <stdexcept>
23
24namespace luci_interpreter
25{
26namespace kernels
27{
28
29namespace
30{
31
32// TODO Extract this function to Utils.h
33Shape extractShapeFromTensor(const Tensor *tensor)
34{
35 Shape shape(tensor->shape().num_elements());
36
37 // Ensures the shape is 1D tensor
38 LUCI_INTERPRETER_CHECK(tensor->shape().num_dims() == 1);
39
40 if (tensor->element_type() == DataType::S32)
41 {
42 const auto *shape_data = tensor->data<int32_t>();
43 for (int i = 0; i < tensor->shape().num_elements(); ++i)
44 {
45 // Ensures the dim value of shape is positive.
46 LUCI_INTERPRETER_CHECK(shape_data[i] >= 0);
47
48 shape.dim(i) = shape_data[i];
49 }
50 }
51 else if (tensor->element_type() == DataType::S64)
52 {
53 const auto *shape_data = tensor->data<int64_t>();
54 for (int i = 0; i < tensor->shape().num_elements(); ++i)
55 {
56 // Ensures the dim value of shape is positive.
57 LUCI_INTERPRETER_CHECK(shape_data[i] >= 0);
58
59 shape.dim(i) = static_cast<int32_t>(shape_data[i]);
60 // Check value overflow
61 LUCI_INTERPRETER_CHECK(static_cast<int64_t>(shape.dim(i)) == shape_data[i]);
62 }
63 }
64 else
65 {
67 }
68 return shape;
69}
70
71} // namespace
72
73BroadcastTo::BroadcastTo(const Tensor *input, const Tensor *shape, Tensor *output)
74 : Kernel({input, shape}, {output})
75{
76}
77
79{
80 LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
81
82 Shape output_shape = extractShapeFromTensor(shape());
83
84 int input_rank = input()->shape().num_dims();
85 int output_rank = output_shape.num_dims();
86
87 // Ensures output rank is not less than input rank
88 LUCI_INTERPRETER_CHECK(input_rank <= output_rank);
89
90 // Check if output shape is broadcastable from input shape
91 // from https://www.tensorflow.org/api_docs/python/tf/broadcast_to
92 // if a tensor has fewer axes than necessary its shape is padded on the left with ones.
93 int extending_rank = output_rank - input_rank;
94 for (int idx = 0; idx < input_rank; ++idx)
95 {
96 LUCI_INTERPRETER_CHECK(input()->shape().dim(idx) == 1 ||
97 input()->shape().dim(idx) == output_shape.dim(extending_rank + idx));
98 }
99
101}
102
104{
105 switch (input()->element_type())
106 {
107 case DataType::FLOAT32:
108 evalFloat();
109 break;
110 case DataType::BOOL:
111 evalBool();
112 break;
113 default:
114 throw std::runtime_error("luci-intp BroadcastTo Unsupported type.");
115 }
116}
117
118void BroadcastTo::evalFloat() const
119{
120 luci_interpreter_pal::BroadcastTo(getTensorShape(input()), getTensorData<char>(input()),
121 getTensorShape(output()), getTensorData<char>(output()),
122 TfLiteType::kTfLiteFloat32);
123}
124
125void BroadcastTo::evalBool() const
126{
127 luci_interpreter_pal::BroadcastTo(getTensorShape(input()), getTensorData<char>(input()),
128 getTensorShape(output()), getTensorData<char>(output()),
129 TfLiteType::kTfLiteBool);
130}
131
132} // namespace kernels
133} // namespace luci_interpreter
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
BroadcastTo(const Tensor *input, const Tensor *shape, Tensor *output)
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
Definition Shape.h:28