ONE - On-device Neural Engine
Loading...
Searching...
No Matches
CircleFullyConnected.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
17
18#include "CircleCloneNode.h"
20
21#include "Check.h"
22
23namespace luci
24{
25
// Clone visitor for a FullyConnected node: creates a new luci::CircleFullyConnected
// in `_graph->nodes()`, copies `weights_format` and `keep_num_dims` from `node`
// onto it, and returns the clone (or nullptr if nothing was created).
// NOTE(review): this is an extracted source listing with gaps. Source line 26
// (the signature), lines 28 and 30 (presumably the conditions guarding the two
// early `return nullptr;` statements) and line 36 (presumably another attribute
// copy) are absent here, so the early returns below *look* unconditional.
// Verify against the original source file before relying on this view.
27{
29 return nullptr;
31 return nullptr;
32
33 auto *cloned = _graph->nodes()->create<luci::CircleFullyConnected>();
// Populate attributes only when the node was actually created.
34 if (cloned != nullptr)
35 {
37 cloned->weights_format(node->weights_format());
38 cloned->keep_num_dims(node->keep_num_dims());
39 }
40 return cloned;
41}
42
43namespace sinf
44{
45
47{
48 auto input_shape = circle_shape(loco::must_cast<CircleNode *>(node->input()));
49 auto weights_shape = circle_shape(loco::must_cast<CircleNode *>(node->weights()));
50
51 loco::TensorShape out_shape;
52
53 // NOTE Some recipes in some repositories are using rank 4 input for FullyConnected.
54 // Until they are all fixed, disable following assert.
55 // TODO Enable following assert after related fixes are applied
56 // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L194
57 // LUCI_ASSERT(input_shape.rank() == 2 || input_shape.rank() == 3,
58 // "Input rank of FullyConnected should be 2 or 3");
59
60 // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L225
61 LUCI_ASSERT(weights_shape.rank() == 2, "Weights of FullyConnected should be 2");
62 LUCI_ASSERT(weights_shape.dim(0).known() && weights_shape.dim(1).known(),
63 "Weights of FullyConnected should be known")
64 // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L353-L367
65
66 /*
67 * **Pre-conditions:**
68 * input_shape.rank() <= 4
69 * * remark: TFLite allows <=3 ranks, but there are rank 4 input recipes in ONE
70 * weights_shape.rank() == 2 and all dimensions are known.
71 * When runtime(input_shape[-1] and weights_shape[-1] are both known), it should be same value.
72 *
73 * **Shape Inference Rule:**
74 * **Input Shape:**
75 * input_shape : (A, B, C, D)
76 * weights_shape : (E, F)
77 * A, B, C, D are "positive numbers" or "unknown".
78 * E, F are always "positive numbers".
79 *
80 * **Output Shape:**
81 * If keep_dims = True : (A, B, C, E)
82 * If keep_dims = False : (G, E)
83 * * G = unknown (if any of A, B, or C is unknown.)
84 * * G = A * B * C (otherwise.)
85 */
86
87 if (node->keep_num_dims())
88 {
89 out_shape.rank(input_shape.rank());
90 for (uint32_t i = 0; i < input_shape.rank(); ++i)
91 out_shape.dim(i) = input_shape.dim(i);
92 out_shape.dim(out_shape.rank() - 1) = weights_shape.dim(0);
93 }
94 else
95 {
96 bool is_dynamic_shape = false;
97
98 for (uint32_t i = 0; i < input_shape.rank() - 1; i++)
99 {
100 if (not input_shape.dim(i).known())
101 {
102 is_dynamic_shape = true;
103 break;
104 }
105 }
106
107 uint32_t batch_size = 1;
108
109 for (uint32_t i = 0; i < input_shape.rank() - 1; i++)
110 {
111 batch_size *= input_shape.dim(i).value();
112 }
113
114 out_shape.rank(2);
115 if (is_dynamic_shape)
116 out_shape.dim(0).unset();
117 else
118 out_shape.dim(0) = batch_size;
119 out_shape.dim(1) = weights_shape.dim(0);
120 }
121
122 return out_shape;
123}
124
125} // namespace sinf
126
127} // namespace luci
Referenced symbols (documentation cross-reference tooltips):
- `void unset(void)` — Definition Dimension.h:59
- `uint32_t value(void) const` — Return the value. Definition Dimension.h:51
- `const Dimension & dim(uint32_t axis) const` — Definition TensorShape.h:38
- `uint32_t rank(void) const` — Definition TensorShape.h:35
- FULLY_CONNECTED in Circle.
- `loco::Node * weights(void) const`
- `WeightsFormat weights_format(void) const`
- `loco::Node * input(void) const`
- `loco::TensorShape visit(const luci::CircleNode *node) final` — Default fallback.
- `#define LUCI_ASSERT(condition, msg)` — Definition Check.h:26
- `loco::TensorShape circle_shape(const luci::CircleNode *node)`