ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Fill.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "kernels/Fill.h"
18#include "kernels/Utils.h"
19#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20
21namespace luci_interpreter
22{
23namespace kernels
24{
25
26Fill::Fill(const Tensor *dims, const Tensor *value, Tensor *output)
27 : Kernel({dims, value}, {output})
28{
29}
30
31template <typename T> void Fill::configureShape()
32{
33 const auto dims_data = getTensorData<T>(dims());
34 Shape output_shape(dims()->shape().dim(0));
35
36 for (int i = 0; i < output_shape.num_dims(); ++i)
37 {
38 T data = dims_data[i];
39 if (data < 0)
40 throw std::runtime_error("Fill dimensions must be >= 0");
41
42 output_shape.dim(i) = data;
43 }
44
46}
47
49{
50 const auto dims_shape = dims()->shape();
51 const auto value_shape = value()->shape();
52
53 // Make sure the 1st input tensor is 1-D
54 LUCI_INTERPRETER_CHECK(dims_shape.num_dims() == 1);
55
56 // Make sure the 1st input tensor is int32 or int64
57 LUCI_INTERPRETER_CHECK(dims()->element_type() == DataType::S32 or
58 dims()->element_type() == DataType::S64);
59
60 // Make sure the 2nd input tensor is a scalar
61 LUCI_INTERPRETER_CHECK(value_shape.num_dims() == 0)
62
63 // Check zero point and scale for S16 and S8
64 if (value()->element_type() == loco::DataType::S16 or
65 value()->element_type() == loco::DataType::S8)
66 {
67 LUCI_INTERPRETER_CHECK(value()->scale() == output()->scale());
68 LUCI_INTERPRETER_CHECK(value()->zero_point() == output()->zero_point());
69
70 if (value()->element_type() == loco::DataType::S16)
71 LUCI_INTERPRETER_CHECK(value()->zero_point() == 0);
72 }
73 // Resize output
74 switch (dims()->element_type())
75 {
76 case DataType::S32:
77 configureShape<int32_t>();
78 break;
79 case DataType::S64:
80 configureShape<int64_t>();
81 break;
82 default:
83 throw std::runtime_error("luci-intp Fill(1) Unsupported type.");
84 }
85}
86
87void Fill::execute() const
88{
89 switch (output()->element_type())
90 {
91 case DataType::S8:
92 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int8_t>(value()),
93 getTensorShape(output()), getTensorData<int8_t>(output()));
94 break;
95 case DataType::S16:
96 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int16_t>(value()),
97 getTensorShape(output()), getTensorData<int16_t>(output()));
98 break;
99 case DataType::S32:
100 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int32_t>(value()),
101 getTensorShape(output()), getTensorData<int32_t>(output()));
102 break;
103 case DataType::S64:
104 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int64_t>(value()),
105 getTensorShape(output()), getTensorData<int64_t>(output()));
106 break;
107 case DataType::FLOAT32:
108 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<float>(value()),
109 getTensorShape(output()), getTensorData<float>(output()));
110 break;
111 default:
112 throw std::runtime_error("luci-intp Fill(2) Unsupported type.");
113 }
114}
115
116} // namespace kernels
117} // namespace luci_interpreter
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
Fill(const Tensor *dims, const Tensor *value, Tensor *output)
Definition Fill.cpp:26
Tensor * output() const
Definition Fill.h:35
void configure() override
Definition Fill.cpp:48
const Tensor * dims() const
Definition Fill.h:33
const Tensor * value() const
Definition Fill.h:34
void execute() const override
Definition Fill.cpp:87
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
const T * data(const std::vector< T, Alloc > &v)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
Definition Shape.h:28