ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
CircleFullyConnected.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
17
18#include "CircleCloneNode.h"
20
21#include "Check.h"
22
23namespace luci
24{
25
27{
29 return nullptr;
31 return nullptr;
32
33 auto *cloned = _graph->nodes()->create<luci::CircleFullyConnected>();
34 {
36 cloned->weights_format(node->weights_format());
37 cloned->keep_num_dims(node->keep_num_dims());
38 }
39 return cloned;
40}
41
42namespace sinf
43{
44
46{
47 auto input_shape = circle_shape(loco::must_cast<CircleNode *>(node->input()));
48 auto weights_shape = circle_shape(loco::must_cast<CircleNode *>(node->weights()));
49
50 loco::TensorShape out_shape;
51
52 // NOTE Some recipes in some repositories are using rank 4 input for FullyConnected.
53 // Until they are all fixed, disable following assert.
54 // TODO Enable following assert after related fixes are applied
55 // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L194
56 // LUCI_ASSERT(input_shape.rank() == 2 || input_shape.rank() == 3,
57 // "Input rank of FullyConnected should be 2 or 3");
58
59 // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L225
60 LUCI_ASSERT(weights_shape.rank() == 2, "Weights of FullyConnected should be 2");
61 LUCI_ASSERT(weights_shape.dim(0).known() && weights_shape.dim(1).known(),
62 "Weights of FullyConnected should be known")
63 // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L353-L367
64
65 /*
66 * **Pre-conditions:**
67 * input_shape.rank() <= 4
68 * * remark: TFLite allows <=3 ranks, but there are rank 4 input recipes in ONE
69 * weights_shape.rank() == 2 and all dimensions are known.
70 * When runtime(input_shape[-1] and weights_shape[-1] are both known), it should be same value.
71 *
72 * **Shape Inference Rule:**
73 * **Input Shape:**
74 * input_shape : (A, B, C, D)
75 * weights_shape : (E, F)
76 * A, B, C, D are "positive numbers" or "unknown".
77 * E, F are always "positive numbers".
78 *
79 * **Output Shape:**
80 * If keep_dims = True : (A, B, C, E)
81 * If keep_dims = False : (G, E)
82 * * G = unknown (if any of A, B, or C is unknown.)
83 * * G = A * B * C (otherwise.)
84 */
85
86 if (node->keep_num_dims())
87 {
88 out_shape.rank(input_shape.rank());
89 for (uint32_t i = 0; i < input_shape.rank(); ++i)
90 out_shape.dim(i) = input_shape.dim(i);
91 out_shape.dim(out_shape.rank() - 1) = weights_shape.dim(0);
92 }
93 else
94 {
95 bool is_dynamic_shape = false;
96
97 for (uint32_t i = 0; i < input_shape.rank() - 1; i++)
98 {
99 if (not input_shape.dim(i).known())
100 {
101 is_dynamic_shape = true;
102 break;
103 }
104 }
105
106 uint32_t batch_size = 1;
107
108 for (uint32_t i = 0; i < input_shape.rank() - 1; i++)
109 {
110 batch_size *= input_shape.dim(i).value();
111 }
112
113 out_shape.rank(2);
114 if (is_dynamic_shape)
115 out_shape.dim(0).unset();
116 else
117 out_shape.dim(0) = batch_size;
118 out_shape.dim(1) = weights_shape.dim(0);
119 }
120
121 return out_shape;
122}
123
124} // namespace sinf
125} // namespace luci
void unset(void)
Definition Dimension.h:59
uint32_t value(void) const
Return the value.
Definition Dimension.h:51
const Dimension & dim(uint32_t axis) const
Definition TensorShape.h:38
uint32_t rank(void) const
Definition TensorShape.h:35
FULLY_CONNECTED in Circle.
loco::Node * weights(void) const
WeightsFormat weights_format(void) const
loco::Node * input(void) const
loco::TensorShape visit(const luci::CircleNode *node) final
Default fallback.
#define LUCI_ASSERT(condition, msg)
Definition Check.h:26
loco::TensorShape circle_shape(const luci::CircleNode *node)