ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Pack.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
#include "Builders.h"
#include "Utils.h"

#include <algorithm>
#include <cassert>
22
23namespace luci_interpreter
24{
25namespace
26{
27
28template <typename T>
29void packImpl(const circle::Tensor *input0, const circle::Tensor *output,
30 const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph,
31 uint8_t *output_data_raw)
32{
33 const auto *options = cur_op->builtin_options_as_PackOptions();
34
35 const int values_count = options->values_count();
36 int axis = options->axis();
37 const int dimensions = Tensor::num_dims(output);
38
39 const auto input_dims = wrap(input0->shape());
40 const auto output_dims = wrap(output->shape());
41
42 if (axis < 0)
43 {
44 axis += dimensions;
45 }
46
47 int outer_size = 1;
48 for (int i = 0; i < axis; ++i)
49 outer_size *= output_dims[i];
50
51 int copy_size = 1;
52 for (int i = axis + 1; i < dimensions; ++i)
53 copy_size *= output_dims[i];
54
55 int input_size = 1;
56 for (int i = 0; i < input_dims.size(); ++i)
57 input_size *= input_dims[i];
58
59 assert(input_size == copy_size * outer_size);
60
61 T *output_data = kernels::getTensorData<T>(output_data_raw);
62 assert(output_data != nullptr);
63
64 for (int i = 0; i < values_count; ++i)
65 {
66 const auto input_index = cur_op->inputs()->operator[](i);
67 assert(input_index != -1);
68 const auto input = runtime_graph->getCircleTensorByIndex(input_index);
69
70 auto input_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(input));
71 assert(input_data != nullptr);
72 for (int k = 0; k < outer_size; ++k)
73 {
74 const T *input_ptr = input_data + copy_size * k;
75 int loc = k * values_count * copy_size + i * copy_size;
76 T *output_ptr = output_data + loc;
77 for (int j = 0; j < copy_size; ++j)
78 output_ptr[j] = input_ptr[j];
79 }
80 }
81}
82
83} // namespace
84
85void configure_kernel_CirclePack(const circle::Operator *, BaseRuntimeGraph *)
86{
87 // Do nothing
88}
89
90void execute_kernel_CirclePack(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
91{
92 const auto input_index = cur_op->inputs()->operator[](0);
93 const auto output_index = cur_op->outputs()->operator[](0);
94 assert(output_index != -1);
95 assert(input_index != -1);
96 const auto input = runtime_graph->getCircleTensorByIndex(input_index);
97 const auto output = runtime_graph->getCircleTensorByIndex(output_index);
98
99 auto output_data = runtime_graph->getDataByTensor(output);
100 assert(output_data != nullptr);
101
102 switch (Tensor::element_type(output))
103 {
104#ifndef DIS_FLOAT
105 case DataType::FLOAT32:
106 packImpl<float>(input, output, cur_op, runtime_graph, output_data);
107 break;
108#endif // DIS_FLOAT
109#ifndef DIS_QUANT
110 case DataType::S8:
111 packImpl<int8_t>(input, output, cur_op, runtime_graph, output_data);
112 break;
113 case DataType::U8:
114 packImpl<uint8_t>(input, output, cur_op, runtime_graph, output_data);
115 break;
116#endif // DIS_QUANT
117 case DataType::S32:
118 packImpl<int32_t>(input, output, cur_op, runtime_graph, output_data);
119 break;
120 case DataType::S64:
121 packImpl<int64_t>(input, output, cur_op, runtime_graph, output_data);
122 break;
123 default:
124 assert(false && "Unsupported types");
125 }
126}
127
128} // namespace luci_interpreter
const circle::Tensor * getCircleTensorByIndex(int32_t index)
uint8_t * getDataByTensor(const circle::Tensor *raw_tensor)
list input_data
Definition infer.py:29
void execute_kernel_CirclePack(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition Pack.cpp:90
RuntimeGraph BaseRuntimeGraph
void configure_kernel_CirclePack(const circle::Operator *, BaseRuntimeGraph *)
Definition Pack.cpp:85
VectorWrapper< T > wrap(const flatbuffers::Vector< T > *vec)
This file contains utility macro.