ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
ShapeInference.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
18#include "loco/IR/Algorithm.h"
19
20#include <cassert>
21#include <memory>
22
23namespace
24{
25
26bool inputs_shape_ready(loco::Node *node)
27{
28 assert(node != nullptr);
29
30 for (uint32_t arity = 0; arity < node->arity(); ++arity)
31 {
32 if (!loco::ShapeInference::known(node->arg(arity)))
33 {
34 return false;
35 }
36 }
37 return true;
38}
39
40} // namespace
41
//
// Infrastructure
//
45namespace
46{
47
48struct ShapeAnnotation : public loco::NodeAnnotation
49{
50public:
51 ShapeAnnotation(const loco::NodeShape &shape) : _shape{shape}
52 {
53 // DO NOTHING
54 }
55
56public:
57 const loco::NodeShape &shape(void) const { return _shape; }
58
59private:
60 loco::NodeShape _shape;
61};
62
63} // namespace
64
65namespace loco
66{
67
// Runs one shape-inference pass (rule API v1) and returns whether any node
// was newly given a ShapeAnnotation.
// NOTE(review): this is an extracted documentation listing. The function's
// signature (listing line 68) and the loop header that binds `node` (listing
// line 74) are missing from this extract. The page's cross-reference list
// mentions postorder_traversal() and output_nodes(Graph *), so the loop
// presumably walks nodes reachable from the graph outputs. Confirm this
// against the real source.
69{
// Only rules that advertise API v1 may be driven by this pass.
70 assert(_rule->support(ShapeInferenceRule::API::V1) && "API v1 is unavailable");
71
// Becomes true as soon as any node receives a new annotation.
72 bool changed = false;
73
75 {
// Skip nodes whose dialect this rule does not recognize.
76 if (_rule->recognize(node->dialect()))
77 {
78 loco::NodeShape shape;
79
// Infer only when the node is not yet shape-annotated (per shape_known(),
// defined elsewhere) and all of its arguments already have known shapes.
80 if (!shape_known(node) && inputs_shape_ready(node))
81 {
// On success, attach the inferred shape and record that the graph changed.
82 if (_rule->infer(node, shape))
83 {
84 node->annot(std::make_unique<ShapeAnnotation>(shape));
85 changed = true;
86 }
87 }
88 }
89 }
90
91 return changed;
92}
93
94bool ShapeInference::known(const Node *node) { return node->annot<ShapeAnnotation>() != nullptr; }
95
// Returns the shape stored in the node's ShapeAnnotation.
// NOTE(review): the signature line (listing line 96) is missing from this
// extract. The page's cross-reference list shows
// `static NodeShape get(const Node *)`; confirm against the real source.
// The assert requires known(node). When asserts are compiled out (NDEBUG),
// calling this on a node without an annotation dereferences a null pointer,
// so callers must check known() first.
97{
98 assert(known(node));
99 return node->annot<ShapeAnnotation>()->shape();
100}
101
102void ShapeInference::erase(Node *node) { node->annot<ShapeAnnotation>(nullptr); }
103
104} // namespace loco
const T * annot(void) const
Retrieve a stored annotation of type T.
A neural network graph.
Definition Graph.h:161
Logical unit of computation.
Definition Node.h:54
virtual Node * arg(uint32_t N) const =0
Access N-th argument node.
virtual uint32_t arity(void) const =0
Return the number of arguments.
virtual const Dialect * dialect(void) const =0
Return "Dialect" identifier that this node belongs to.
std::vector< loco::Node * > postorder_traversal(const std::vector< loco::Node * > &roots)
Generate postorder traversal sequence starting from "roots".
Definition Algorithm.cpp:53
bool shape_known(const Node *node)
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
Extensible Node Metadata.
Definition Node.h:39
static NodeShape get(const Node *)
static void erase(Node *)
static bool known(const Node *)
virtual bool support(const API &api) const
Check whether a given API is available or not.
virtual bool infer(const Node *, NodeShape &) const =0
Infer node's shape.
virtual bool recognize(const Dialect *) const =0
Return true if this rule recognizes a given dialect.