template <class ARG_TYPE_1, class ARG_TYPE_2> class NodeFiller final
{
public:
  NodeFiller(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) : _arg_1(arg_1), _arg_2(arg_2)
  {
    // empty body
  }

  // Fill '*_arg_1' and '*_arg_2' with the two arguments of a commutative node,
  // trying both argument orders.
  template <class COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node);

private:
  ARG_TYPE_1 **_arg_1;
  ARG_TYPE_2 **_arg_2;
};
template <class ARG_TYPE_1, class ARG_TYPE_2>
inline NodeFiller<ARG_TYPE_1, ARG_TYPE_2> fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
{
  return NodeFiller<ARG_TYPE_1, ARG_TYPE_2>{arg_1, arg_2};
}
template <class ARG_TYPE_1, class ARG_TYPE_2>
template <class COMM_NODE>
bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_commutative_args_of(const COMM_NODE *node)
{
  // Case 1) x() is ARG_TYPE_1 and y() is ARG_TYPE_2
  {
    auto x = dynamic_cast<ARG_TYPE_1 *>(node->x());
    auto y = dynamic_cast<ARG_TYPE_2 *>(node->y());
    // ... (on success, store x and y into *_arg_1 and *_arg_2 and return true)
  }

  // Case 2) the arguments appear in the opposite order
  {
    auto x = dynamic_cast<ARG_TYPE_2 *>(node->x());
    auto y = dynamic_cast<ARG_TYPE_1 *>(node->y());
    // ... (on success, store y and x into *_arg_1 and *_arg_2 and return true)
  }

  return false;
}
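// --- Illustrative sketch (not part of the original source) ---
// A minimal, self-contained demonstration of the fill(...).with_commutative_args_of(...)
// idiom above, using hypothetical Base/Add/Const/Input node types. Because the
// matched operator is commutative, the same matcher accepts the constant on
// either input.
namespace sketch
{
struct Base
{
  virtual ~Base() = default;
};
struct Const : public Base
{
};
struct Input : public Base
{
};
struct Add : public Base
{
  Base *x() const { return _x; }
  Base *y() const { return _y; }
  Base *_x = nullptr;
  Base *_y = nullptr;
};

inline bool capture_input_and_const(const Add *add, Input **ifm, Const **gamma)
{
  // Succeeds whether the constant was attached as x() or as y()
  return fill(ifm, gamma).with_commutative_args_of(add);
}
} // namespace sketch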
  // (inside is_1D_with_dummy_dim: the constant must have shape 1 x ... x 1 x depth)
  auto rank = node->rank();
  uint32_t axis;
  for (axis = 0; axis < rank - 1; ++axis)
    if (node->dim(axis).value() != 1)
      return false;
  return node->dim(axis).value() == depth;
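// Worked example for the check above: a constant of shape [1, 1, 1, 32] satisfies
// is_1D_with_dummy_dim(node, 32), as do [1, 32] and [32]; a shape such as
// [1, 2, 1, 32] fails because a non-trailing axis is not 1.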
  // (inside a helper that checks whether a Mean node averages over the H and W
  //  axes of an NHWC tensor)
  if (input_shape.rank() != 4)
    return false;
  // 'reduction indices' must be exactly the two spatial axes {1, 2}
  if (red_indices->rank() != 1)
    return false;
  std::set<int32_t> red_indices_set;
  {
    assert(red_indices->dtype() == loco::DataType::S32);
    for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
      red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
  }
  if (red_indices_set.size() != 2)
    return false;
  if (red_indices_set.find(1) == red_indices_set.end())
    return false;
  if (red_indices_set.find(2) == red_indices_set.end())
    return false;
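// Worked example (sketch): a Mean node whose input has NHWC shape [1, 16, 16, 8]
// and whose reduction_indices constant holds [1, 2] passes the checks above; it
// averages over H and W, producing one mean per channel. The listing also
// references keep_dims(), presumably checked in the elided remainder of this helper.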
class InstanceNormPattern final
{
public:
  // The constructor takes the candidate terminal Add node:
  //   add_as_terminal = candidate;

  bool matched(); // runs the matcher (defined below)
  bool matched() const { return _matched; }

  // ... (pointers to the matched nodes are elided in this listing)

private:
  bool _matched = false;
};
bool InstanceNormPattern::matched()
{
#define CHECK_OR_FALSE(condition) \
  if (not(condition))             \
    return false;

  // Matching starts at the terminal Add and walks backwards through the subgraph
  CHECK_OR_FALSE(fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
  // ...
  uint32_t ifm_channel_depth = ifm_tensor_shape.dim(3).value();
  // ...
  CHECK_OR_FALSE(
    fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
  // epsilon must be a single FLOAT32 value
  CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
  CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
  // ...
  CHECK_OR_FALSE(fill(/* ... */).with_commutative_args_of(mul_as_scaled_mean));
  // ... (remaining checks elided; _matched is set on success)
}
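// For reference (inferred from the captured node names, not stated in the source):
// the matched subgraph is instance normalization in its expanded form,
//
//   scale = const_as_gamma * rsqrt(variance + const_as_epsilon)
//   out   = ifm * scale + (const_as_beta - mean(ifm) * scale)
//
// which is algebraically gamma * (ifm - mean) / sqrt(var + epsilon) + beta,
// i.e. instance norm applied per channel over the H and W axes.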
void fuse_instance_norm(const InstanceNormPattern &p)
{
  // ...
  auto graph = p.add_as_terminal->graph();

  // Reshape gamma and beta (constants of shape 1 x ... x 1 x C) into rank-1
  // tensors of size C. (Creation of reshape_gamma/reshape_beta and ifm_shape
  // is elided in this listing.)
  uint32_t ifm_channel_depth = ifm_shape.dim(3).value();

  int32_t new_shape[1] = {static_cast<int32_t>(ifm_channel_depth)};

  reshape_gamma->tensor(p.const_as_gamma);
  reshape_beta->tensor(p.const_as_beta);
  // ...

  // Wire up the InstanceNorm node that will replace the matched subgraph
  instance_norm->input(p.ifm);
  instance_norm->gamma(reshape_gamma);
  instance_norm->beta(reshape_beta);
  float epsilon = p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
  instance_norm->epsilon(epsilon);
  instance_norm->fusedActivationFunction(p.add_as_terminal->fusedActivationFunction());
  // ... (the terminal Add is then replaced by the new InstanceNorm node)
}
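// Note on the elided reshape setup: per the documentation of the set_new_shape
// helper referenced by this file, it sets both the reshape node's second input
// (as a constant) and its newShape attribute to the same value, so the
// 1 x ... x 1 x C gamma/beta constants become rank-1 tensors of shape [C].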
  bool changed = false;
  // ... (loop over the graph's active nodes; 'add' is the candidate Add node)
  InstanceNormPattern pattern(add);
  if (not pattern.matched())
    continue;
  fuse_instance_norm(pattern);
  changed = true;
  // ... (end of loop; the pass's run(loco::Graph *g) method returns 'changed')
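// Illustrative usage (sketch): the pass class name below is an assumption; it is
// not visible in this listing.
//   FuseInstanceNormPass pass;
//   bool changed = pass.run(graph); // true if at least one subgraph was fused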