ONE - On-device Neural Engine
Loading...
Searching...
No Matches
MaxPool2D.cpp
Go to the documentation of this file.
/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#include "Pool2DCommon.h"
18
19#include "PALMaxPool2D.h"
20
21namespace luci_interpreter
22{
23void configure_kernel_CircleMaxPool2D(const circle::Operator *cur_op,
24 BaseRuntimeGraph *runtime_graph)
25{
26 configure_kernel_CirclePool2DCommon(cur_op, runtime_graph);
27}
28
29void execute_kernel_CircleMaxPool2D(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
30{
31 const kernels::SISOKernel siso_kernel(cur_op, runtime_graph);
32
33 const auto input = siso_kernel.input();
34 const auto output = siso_kernel.output();
35
36 const auto *input_data = runtime_graph->getDataByTensor(input);
37 auto *output_data = runtime_graph->getDataByTensor(output);
38
39 const DataType input_type = Tensor::element_type(input);
40
41 const auto params = createPoolParams(cur_op, siso_kernel);
42
43 switch (input_type)
44 {
45#ifndef DIS_FLOAT
46 case DataType::FLOAT32:
48 params, kernels::getTensorShape(input), kernels::getTensorData<float>(input_data),
49 kernels::getTensorShape(output), kernels::getTensorData<float>(output_data));
50 break;
51#endif // DIS_FLOAT
52#ifndef DIS_QUANT
53 case DataType::S8:
54 case DataType::S16:
56 params, kernels::getTensorShape(input), kernels::getTensorData<uint8_t>(input_data),
57 kernels::getTensorShape(output), kernels::getTensorData<uint8_t>(output_data), input_type);
58 break;
59#endif // DIS_QUANT
60 default:
61 assert(false && "Unsupported type.");
62 }
63}
64
65} // namespace luci_interpreter
uint8_t * getDataByTensor(const circle::Tensor *raw_tensor)
const circle::Tensor * output() const
Definition SISOKernel.h:47
const circle::Tensor * input() const
Definition SISOKernel.h:46
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
void MaxPool(const PoolParams &params, const luci_interpreter::RuntimeShape &input_shape, const uint8_t *input_data, const luci_interpreter::RuntimeShape &output_shape, uint8_t *output_data, luci_interpreter::DataType data_type)
luci_interpreter_pal::PoolParams createPoolParams(const circle::Operator *cur_op, const kernels::SISOKernel &siso_kernel)
void configure_kernel_CirclePool2DCommon(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
void configure_kernel_CircleMaxPool2D(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition MaxPool2D.cpp:23
DataType
"scalar" value type
Definition DataType.h:32
void execute_kernel_CircleMaxPool2D(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition MaxPool2D.cpp:29