ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Slice.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "kernels/Slice.h"
18#include "Utils.h"
19#include "PALSlice.h"
20
21#include <cassert>
22#include <cstring>
23
24namespace luci_interpreter
25{
26
27namespace kernels
28{
// Highest input rank handled by this kernel. In execute() the begin/size
// values are padded up to this many dimensions before the PAL Slice call.
constexpr int max_dim = 4;
30
31Slice::Slice(const Tensor *input, const Tensor *begin, const Tensor *size, Tensor *output)
32 : Kernel({input, begin, size}, {output})
33{
34}
35
36template <typename T>
38{
39 Shape output_shape = Shape(input->shape().num_dims());
40 for (int idx = 0; idx < input->shape().num_dims(); idx++)
41 {
42 T size_value = getTensorData<T>(size)[idx];
43 if (size_value < 0)
44 {
45 if (size_value != -1)
46 {
47 throw std::runtime_error("Invalid size.");
48 }
49 size_value = input->shape().dim(idx) - getTensorData<T>(begin)[idx];
50 }
51 else
52 {
53 if (input->shape().dim(idx) < getTensorData<T>(begin)[idx] + size_value)
54 {
55 throw std::runtime_error("Invalid begin and size.");
56 }
57 }
58 output_shape.dim(idx) = static_cast<int>(size_value);
59 }
60 return output_shape;
61}
62
63template <typename T>
64void getBeginAndSizeVectors(int dimensions, const Tensor *begin, const Tensor *size,
65 std::vector<int> *begins, std::vector<int> *sizes)
66{
67 for (int idx = dimensions - 1; idx >= 0; --idx)
68 {
69 begins->push_back(getTensorData<T>(begin)[idx]);
70 sizes->push_back(getTensorData<T>(size)[idx]);
71 }
72}
73
75{
76 assert(input()->element_type() == output()->element_type());
77 assert(begin()->element_type() == DataType::S32 || begin()->element_type() == DataType::S64);
78 assert(size()->element_type() == DataType::S32 || size()->element_type() == DataType::S64);
79 assert(begin()->shape().num_dims() == 1);
80 assert(size()->shape().num_dims() == 1);
81 assert(input()->shape().num_dims() <= max_dim);
82
83 if (begin()->element_type() == DataType::S32)
84 {
85 output()->resize(calculateOutputShape<int32_t>(input(), begin(), size()));
86 }
87 else if (begin()->element_type() == DataType::S64)
88 {
89 output()->resize(calculateOutputShape<int64_t>(input(), begin(), size()));
90 }
91 else
92 {
93 throw std::runtime_error("luci-intp Slice Unsupported type.");
94 }
95}
96
97void Slice::execute() const
98{
99 std::vector<int> begins;
100 begins.reserve(max_dim);
101 std::vector<int> sizes;
102 sizes.reserve(max_dim);
103 if (begin()->element_type() == DataType::S32)
104 {
105 getBeginAndSizeVectors<int32_t>(input()->shape().num_dims(), begin(), size(), &begins, &sizes);
106 }
107 else if (begin()->element_type() == DataType::S64)
108 {
109 getBeginAndSizeVectors<int64_t>(input()->shape().num_dims(), begin(), size(), &begins, &sizes);
110 }
111 else
112 {
113 throw std::runtime_error("Unsupported begin type.");
114 }
115 for (int i = input()->shape().num_dims(); i < max_dim; ++i)
116 {
117 begins.push_back(0);
118 sizes.push_back(1);
119 }
120
121 assert(begins.size() == 4);
122 assert(sizes.size() == 4);
123 tflite::SliceParams op_params{};
124 op_params.begin_count = 4;
125 op_params.size_count = 4;
126 for (int i = 0; i < 4; i++)
127 {
128 op_params.begin[i] = begins[3 - i];
129 op_params.size[i] = sizes[3 - i];
130 }
131 switch (input()->element_type())
132 {
133 case DataType::FLOAT32:
134 luci_interpreter_pal::Slice(op_params, getTensorShape(input()), getTensorData<float>(input()),
135 getTensorShape(output()), getTensorData<float>(output()));
136 break;
137 case DataType::U8:
138 luci_interpreter_pal::Slice(op_params, getTensorShape(input()),
139 getTensorData<uint8_t>(input()), getTensorShape(output()),
140 getTensorData<uint8_t>(output()));
141 break;
142 case DataType::S8:
143 luci_interpreter_pal::Slice(op_params, getTensorShape(input()),
144 getTensorData<int8_t>(input()), getTensorShape(output()),
145 getTensorData<int8_t>(output()));
146 break;
147 default:
148 throw std::runtime_error("Unsupported input type.");
149 }
150}
151
152} // namespace kernels
153} // namespace luci_interpreter
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Tensor * begin() const
Definition Slice.h:33
Slice(const Tensor *input, const Tensor *begin, const Tensor *size, Tensor *output)
Definition Slice.cpp:31
Tensor * output() const
Definition Slice.h:35
const Tensor * input() const
Definition Slice.h:32
void execute() const override
Definition Slice.cpp:97
const Tensor * size() const
Definition Slice.h:34
const luci_interpreter::RuntimeShape output_shape
Shape calculateOutputShape(const Tensor *input, const Tensor *begin, const Tensor *size)
Definition Slice.cpp:37
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
void getBeginAndSizeVectors(int dimensions, const Tensor *begin, const Tensor *size, std::vector< int > *begins, std::vector< int > *sizes)
Definition Slice.cpp:64
int32_t size[5]
Definition Slice.cpp:35
int32_t begin[5]
Definition Slice.cpp:33
This file contains utility macro.