ONE - On-device Neural Engine
Loading...
Searching...
No Matches
SpaceToBatchND.cpp
Go to the documentation of this file.
/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17
19#include "kernels/Utils.h"
20
21#include "PALSpaceToBatchND.h"
22
23#include <stdexcept>
24
25namespace luci_interpreter
26{
27namespace kernels
28{
namespace
{

// Inclusive bounds on the input tensor rank accepted by SpaceToBatchND.
// configure() treats dimension 0 as batch and the last dimension as the
// pass-through dimension, so these bounds allow one or two spatial dimensions.
constexpr int kInputMinDimensionNum = 3;
constexpr int kInputMaxDimensionNum = 4;

} // namespace
36
37SpaceToBatchND::SpaceToBatchND(const Tensor *input, const Tensor *block_shape,
38 const Tensor *paddings, Tensor *output)
39 : Kernel({input, block_shape, paddings}, {output})
40{
41}
42
44{
45 const auto *block_shape_data = block_shape()->data<int32_t>();
46 const auto *paddings_data = paddings()->data<int32_t>();
47 LUCI_INTERPRETER_CHECK(input()->shape().num_dims() >= kInputMinDimensionNum);
48 LUCI_INTERPRETER_CHECK(input()->shape().num_dims() <= kInputMaxDimensionNum);
49 LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
50
51 int spatial_dims_num = input()->shape().num_dims() - 2;
52
53 LUCI_INTERPRETER_CHECK(block_shape()->shape().num_dims() == 1);
54 LUCI_INTERPRETER_CHECK(block_shape()->shape().dim(0) == spatial_dims_num);
55
56 LUCI_INTERPRETER_CHECK(paddings()->shape().num_dims() == 2);
57 LUCI_INTERPRETER_CHECK(paddings()->shape().dim(0) == spatial_dims_num);
58 LUCI_INTERPRETER_CHECK(paddings()->shape().dim(1) == 2);
59
60 Shape output_shape = Shape(input()->shape().num_dims());
61 int output_batch_size = input()->shape().dim(0);
62 for (int i = 0; i < spatial_dims_num; ++i)
63 {
64 int final_dim_size =
65 (input()->shape().dim(i + 1) + paddings_data[i * 2] + paddings_data[i * 2 + 1]);
66 LUCI_INTERPRETER_CHECK(final_dim_size % block_shape_data[i] == 0);
67 output_shape.dim(i + 1) = final_dim_size / block_shape_data[i];
68 output_batch_size = output_batch_size * block_shape_data[i];
69 }
70 output_shape.dim(0) = output_batch_size;
71 output_shape.dim(input()->shape().num_dims() - 1) =
72 input()->shape().dim(input()->shape().num_dims() - 1);
74}
75
77{
78 switch (input()->element_type())
79 {
80 tflite::SpaceToBatchParams op_params;
81 case DataType::FLOAT32:
82 op_params.output_offset = 0;
83 luci_interpreter_pal::SpaceToBatchND(
84 op_params, getTensorShape(input()), getTensorData<float>(input()),
85 getTensorShape(block_shape()), getTensorData<int32_t>(block_shape()),
86 getTensorShape(paddings()), getTensorData<int32_t>(paddings()), getTensorShape(output()),
87 getTensorData<float>(output()));
88 break;
89 case DataType::U8:
90 op_params.output_offset = output()->zero_point();
91 luci_interpreter_pal::SpaceToBatchND(
92 op_params, getTensorShape(input()), getTensorData<uint8_t>(input()),
93 getTensorShape(block_shape()), getTensorData<int32_t>(block_shape()),
94 getTensorShape(paddings()), getTensorData<int32_t>(paddings()), getTensorShape(output()),
95 getTensorData<uint8_t>(output()));
96 break;
97 default:
98 throw std::runtime_error("luci-intp ShapeToBatchND Unsupported type.");
99 }
100}
101
102} // namespace kernels
103} // namespace luci_interpreter
int32_t dim(int i) const
Definition Tensor.h:41
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
const T * data() const
Definition Tensor.h:127
int32_t zero_point() const
Definition Tensor.h:115
SpaceToBatchND(const Tensor *input, const Tensor *block_shape, const Tensor *paddings, Tensor *output)
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194