ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Reshape.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "Builders.h"
19#include "Utils.h"
20
21#include <cassert>
22#include <cstring>
23
24namespace luci_interpreter
25{
26
27void configure_kernel_CircleReshape(const circle::Operator *, BaseRuntimeGraph *)
28{
29 // Do nothing
30}
31
32// TODO: reduce code duplication with ExpandDims
33void execute_kernel_CircleReshape(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
34{
35 const auto input_index = cur_op->inputs()->operator[](0);
36 const auto shape_index = cur_op->inputs()->operator[](1);
37 const auto output_index = cur_op->outputs()->operator[](0);
38
39 assert(input_index != -1);
40 assert(shape_index != -1);
41 assert(output_index != -1);
42
43 const auto input = runtime_graph->getCircleTensorByIndex(input_index);
44 const auto shape = runtime_graph->getCircleTensorByIndex(shape_index);
45 const auto output = runtime_graph->getCircleTensorByIndex(output_index);
46 bool is_inplace = runtime_graph->is_inplace_op(cur_op);
47 if (is_inplace)
48 {
49 runtime_graph->makeInplaceOperation(input, output);
50 return;
51 }
52
53 const auto input_data = runtime_graph->getDataByTensor(input);
54 auto shape_data = runtime_graph->getConstDataByTensor(shape);
55 auto output_data = runtime_graph->getDataByTensor(output);
56
57 assert(input_data != nullptr);
58 assert(output_data != nullptr);
59
60 int32_t data_size = Tensor::num_elements(output) * getDataTypeSize(Tensor::element_type(output));
61
62#ifndef DIS_DYN_SHAPES
63 if (shape_data == nullptr)
64 {
65 shape_data = runtime_graph->getDataByTensor(shape);
66 assert(shape_data != nullptr);
67
68 assert(Tensor::element_type(shape) == DataType::S32);
69
70 const int32_t *shape_data_int = kernels::getTensorData<int32_t>(shape_data);
71 const auto num_dims = Tensor::num_dims(output);
72
73 luci_interpreter::RuntimeShape dynamic_shape(num_dims);
74 data_size = 1;
75 for (int i = 0; i < num_dims; ++i)
76 {
77 dynamic_shape.setDim(i, shape_data_int[i]);
78 data_size *= shape_data_int[i];
79 }
80 data_size *= size(Tensor::element_type(output));
81
82 runtime_graph->addDynamicShapeTensor(output, std::move(dynamic_shape));
83
84 if (data_size == 0)
85 {
86 runtime_graph->resetTensorData(nullptr, output);
87 return;
88 }
89
90 auto new_output_data = new uint8_t[data_size];
91 output_data = new_output_data;
92 runtime_graph->resetTensorData(new_output_data, output);
93 }
94#else
95 assert(shape_data != nullptr);
96#endif // DIS_DYN_SHAPES
97
98 std::memcpy(output_data, input_data, data_size);
99}
100
101} // namespace luci_interpreter
void makeInplaceOperation(const circle::Tensor *src_tensor, const circle::Tensor *dst_tensor)
uint8_t * getConstDataByTensor(const circle::Tensor *raw_tensor)
const circle::Tensor * getCircleTensorByIndex(int32_t index)
void addDynamicShapeTensor(const circle::Tensor *tensor, luci_interpreter::RuntimeShape &&shapes)
void resetTensorData(uint8_t *new_data, const circle::Tensor *tensor)
bool is_inplace_op(const circle::Operator *op)
uint8_t * getDataByTensor(const circle::Tensor *raw_tensor)
void setDim(int i, int32_t val)
Definition Tensor.h:114
void execute_kernel_CircleReshape(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition Reshape.cpp:33
void configure_kernel_CircleReshape(const circle::Operator *, BaseRuntimeGraph *)
Definition Reshape.cpp:27
size_t getDataTypeSize(DataType data_type)
Definition DataType.h:33
int32_t size[5]
Definition Slice.cpp:35
This file contains utility macros.