ONE - On-device Neural Engine
Loading...
Searching...
No Matches
AveragePool2D.cpp
Go to the documentation of this file.
1
/*
2
* Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3
*
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
* you may not use this file except in compliance with the License.
6
* You may obtain a copy of the License at
7
*
8
* http://www.apache.org/licenses/LICENSE-2.0
9
*
10
* Unless required by applicable law or agreed to in writing, software
11
* distributed under the License is distributed on an "AS IS" BASIS,
12
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
* See the License for the specific language governing permissions and
14
* limitations under the License.
15
*/
16
17
#include "
Pool2DCommon.h
"
18
19
#include "PALAveragePool2D.h"
20
21
namespace
luci_interpreter
22
{
23
24
// TODO: reduce code duplication with MaxPool2D
25
void
configure_kernel_CircleAveragePool2D
(
const
circle::Operator *
cur_op
,
26
BaseRuntimeGraph
*runtime_graph)
27
{
28
configure_kernel_CirclePool2DCommon
(
cur_op
, runtime_graph);
29
}
30
31
void
execute_kernel_CircleAveragePool2D
(
const
circle::Operator *
cur_op
,
32
BaseRuntimeGraph
*runtime_graph)
33
{
34
const
kernels::SISOKernel
siso_kernel
(
cur_op
, runtime_graph);
35
36
const
auto
input =
siso_kernel
.input();
37
const
auto
output =
siso_kernel
.output();
38
39
const
auto
*input_data = runtime_graph->
getDataByTensor
(input);
40
auto
*output_data = runtime_graph->
getDataByTensor
(output);
41
42
const
DataType
input_type = Tensor::element_type(input);
43
44
const
auto
params =
createPoolParams
(
cur_op
,
siso_kernel
);
45
46
switch
(input_type)
47
{
48
#ifndef DIS_FLOAT
49
case
DataType::FLOAT32:
50
luci_interpreter_pal::AveragePool(
51
params,
kernels::getTensorShape
(input), kernels::getTensorData<float>(input_data),
52
kernels::getTensorShape
(output), kernels::getTensorData<float>(output_data));
53
break
;
54
#endif
// DIS_FLOAT
55
#ifndef DIS_QUANT
56
case
DataType::S8:
57
case
DataType::S16:
58
luci_interpreter_pal::AveragePool(
59
params,
kernels::getTensorShape
(input), kernels::getTensorData<uint8_t>(input_data),
60
kernels::getTensorShape
(output), kernels::getTensorData<uint8_t>(output_data), input_type);
61
break
;
62
#endif
// DIS_QUANT
63
default
:
64
assert(
false
&&
"Unsupported type."
);
65
}
66
}
67
68
}
// namespace luci_interpreter
Pool2DCommon.h
luci_interpreter::RuntimeGraph
Definition
RuntimeGraph.h:33
luci_interpreter::RuntimeGraph::getDataByTensor
uint8_t * getDataByTensor(const circle::Tensor *raw_tensor)
Definition
RuntimeGraph.cpp:355
luci_interpreter::kernels::SISOKernel
Definition
SISOKernel.h:29
luci_interpreter::kernels::getTensorShape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition
Utils.h:194
luci_interpreter
Definition
BuddyMemoryManager.h:22
luci_interpreter::configure_kernel_CircleAveragePool2D
void configure_kernel_CircleAveragePool2D(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition
AveragePool2D.cpp:25
luci_interpreter::createPoolParams
luci_interpreter_pal::PoolParams createPoolParams(const circle::Operator *cur_op, const kernels::SISOKernel &siso_kernel)
Definition
Pool2DCommon.h:56
luci_interpreter::configure_kernel_CirclePool2DCommon
void configure_kernel_CirclePool2DCommon(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition
Pool2DCommon.h:27
luci_interpreter::execute_kernel_CircleAveragePool2D
void execute_kernel_CircleAveragePool2D(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition
AveragePool2D.cpp:31
luci_interpreter::DataType
DataType
"scalar" value type
Definition
DataType.h:32
luci::must_cast
T must_cast(loco::Node *node)
Definition
CircleNodeDecl.h:95
onert-micro
luci-interpreter
src
kernels
AveragePool2D.cpp
Generated by
1.9.8