ONE - On-device Neural Engine
Loading...
Searching...
No Matches
AveragePool2D.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Pool2DCommon.h"
18
19#include "PALAveragePool2D.h"
20
21namespace luci_interpreter
22{
23
24// TODO: reduce code duplication with MaxPool2D
25void configure_kernel_CircleAveragePool2D(const circle::Operator *cur_op,
26 BaseRuntimeGraph *runtime_graph)
27{
28 configure_kernel_CirclePool2DCommon(cur_op, runtime_graph);
29}
30
31void execute_kernel_CircleAveragePool2D(const circle::Operator *cur_op,
32 BaseRuntimeGraph *runtime_graph)
33{
34 const kernels::SISOKernel siso_kernel(cur_op, runtime_graph);
35
36 const auto input = siso_kernel.input();
37 const auto output = siso_kernel.output();
38
39 const auto *input_data = runtime_graph->getDataByTensor(input);
40 auto *output_data = runtime_graph->getDataByTensor(output);
41
42 const DataType input_type = Tensor::element_type(input);
43
44 const auto params = createPoolParams(cur_op, siso_kernel);
45
46 switch (input_type)
47 {
48#ifndef DIS_FLOAT
49 case DataType::FLOAT32:
50 luci_interpreter_pal::AveragePool(
51 params, kernels::getTensorShape(input), kernels::getTensorData<float>(input_data),
52 kernels::getTensorShape(output), kernels::getTensorData<float>(output_data));
53 break;
54#endif // DIS_FLOAT
55#ifndef DIS_QUANT
56 case DataType::S8:
57 case DataType::S16:
58 luci_interpreter_pal::AveragePool(
59 params, kernels::getTensorShape(input), kernels::getTensorData<uint8_t>(input_data),
60 kernels::getTensorShape(output), kernels::getTensorData<uint8_t>(output_data), input_type);
61 break;
62#endif // DIS_QUANT
63 default:
64 assert(false && "Unsupported type.");
65 }
66}
67
68} // namespace luci_interpreter
uint8_t * getDataByTensor(const circle::Tensor *raw_tensor)
const circle::Tensor * output() const
Definition SISOKernel.h:47
const circle::Tensor * input() const
Definition SISOKernel.h:46
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
void configure_kernel_CircleAveragePool2D(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
luci_interpreter_pal::PoolParams createPoolParams(const circle::Operator *cur_op, const kernels::SISOKernel &siso_kernel)
void configure_kernel_CirclePool2DCommon(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
void execute_kernel_CircleAveragePool2D(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
DataType
"scalar" value type
Definition DataType.h:32