ONE - On-device Neural Engine
Loading...
Searching...
No Matches
TransposeConv.cpp
Go to the documentation of this file.
/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "
Builders.h
"
18
19
#include "
kernels/TransposeConv.h
"
20
21
namespace
luci_interpreter
22
{
23
24
std::unique_ptr<Kernel>
build_kernel_CircleTransposeConv
(
const
luci::CircleNode
*
circle_node
,
25
KernelBuilderHelper
&helper)
26
{
27
const
auto
*node =
loco::must_cast<const luci::CircleTransposeConv *>
(
circle_node
);
28
assert(node->arity() == 4);
29
30
const
Tensor
*input_sizes = helper.
getInputTensor
(node->inputSizes());
31
const
Tensor
*filter = helper.
getInputTensor
(node->filter());
32
const
Tensor
*out_backprop = helper.
getInputTensor
(node->outBackprop());
33
const
Tensor
*bias = helper.
getOptionalInputTensor
(node->bias());
34
35
Tensor
*output = helper.
getOutputTensor
(node);
36
37
DataType
scratch_data_type
=
38
helper.
getInputTensor
(node)->
element_type
() == DataType::S16 ? DataType::S64 : DataType::S32;
39
40
auto
scratch_tensor
=
41
std::make_unique<Tensor>(
scratch_data_type
,
Shape
({}),
AffineQuantization
{},
""
);
42
scratch_tensor
->set_observable(
false
);
43
scratch_tensor
->set_data_buffer(
nullptr
);
44
Tensor
*
tmp
= helper.
getRuntimeGraph
(node->graph())->
addTensor
(std::move(
scratch_tensor
));
45
46
TransposeConvParams
params{};
47
params.
padding
= node->padding();
48
params.stride_height = node->stride()->h();
49
params.stride_width = node->stride()->w();
50
params.activation = node->fusedActivationFunction();
51
52
// TODO support activation
53
if
(params.activation !=
luci::FusedActFunc::NONE
)
54
{
55
throw
std::runtime_error(
"Unsupported activation of TransposeConv"
);
56
}
57
58
return
std::make_unique<kernels::TransposeConv>(input_sizes, filter, out_backprop, bias, output,
59
tmp
, params);
60
}
61
62
}
// namespace luci_interpreter
luci_interpreter::KernelBuilderHelper
Definition
KernelBuilderHelper.h:33
luci_interpreter::KernelBuilderHelper::getOutputTensor
Tensor * getOutputTensor(const loco::Node *node) const
Definition
KernelBuilderHelper.cpp:40
luci_interpreter::KernelBuilderHelper::getOptionalInputTensor
const Tensor * getOptionalInputTensor(const loco::Node *node) const
Definition
KernelBuilderHelper.cpp:31
luci_interpreter::KernelBuilderHelper::getRuntimeGraph
RuntimeGraph * getRuntimeGraph(const loco::Graph *graph) const
Definition
KernelBuilderHelper.cpp:57
luci_interpreter::KernelBuilderHelper::getInputTensor
const Tensor * getInputTensor(const loco::Node *node) const
Definition
KernelBuilderHelper.cpp:24
luci_interpreter::RuntimeGraph::addTensor
Tensor * addTensor(std::unique_ptr< Tensor > &&tensor)
Definition
RuntimeGraph.cpp:118
luci_interpreter::Shape
Definition
Tensor.h:33
luci_interpreter::Tensor
Definition
Tensor.h:101
luci_interpreter::Tensor::element_type
DataType element_type() const
Definition
Tensor.h:105
TransposeConv.h
loco::must_cast
T must_cast(FeatureEncoder *node)
A helper dynamic_cast that throws when failed.
Definition
FeatureCodec.h:80
luci_interpreter
Definition
BuddyMemoryManager.h:22
luci_interpreter::build_kernel_CircleTransposeConv
std::unique_ptr< Kernel > build_kernel_CircleTransposeConv(const luci::CircleNode *circle_node, KernelBuilderHelper &helper)
Definition
TransposeConv.cpp:24
luci_interpreter::DataType
DataType
"scalar" value type
Definition
DataType.h:32
luci::must_cast
T must_cast(loco::Node *node)
Definition
CircleNodeDecl.h:95
luci::FusedActFunc::NONE
@ NONE
Builders.h
luci::CircleNode
Definition
CircleNodeDecl.h:42
luci_interpreter::AffineQuantization
Definition
Tensor.h:94
luci_interpreter::TransposeConvParams
Definition
KernelParams.h:245
luci_interpreter::TransposeConvParams::padding
Padding padding
Definition
KernelParams.h:246
compiler
luci-interpreter
src
loader
nodes
TransposeConv.cpp
Generated by
1.9.8