ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
fme_apply Namespace Reference

Data Structures

struct  EqualizePattern
 
class  FMEqualizer
 
class  FusePostScalePass
 Pass to fuse CircleCustom(PostScale) to succeeding Ops. More...
 
class  FusePreScalePass
 Pass to fuse CircleCustom(PreScale) to preceding Ops. More...
 
class  InsertScaleShift
 Class to insert scale/shift virtual Ops to loco::Graph. More...
 
class  ProgressReporter
 

Functions

void check_patterns_valid (loco::Graph *g, const std::vector< EqualizePattern > &patterns)
 
std::vector< EqualizePattern > read (const std::string &filename)
 
std::string random_str (uint32_t len)
 
luci::CircleCustom * to_scale (loco::Node *node)
 
void copy_shape (luci::CircleNode *from, luci::CircleNode *to)
 
loco::Node * get_input (luci::CircleNode *node)
 
void set_input (luci::CircleNode *node, luci::CircleCustom *input)
 
luci::CircleNode * find_arg_with_name (const luci::CircleNode *node, const std::string &name, const uint32_t &depth)
 

Function Documentation

◆ check_patterns_valid()

void fme_apply::check_patterns_valid ( loco::Graph * g,
const std::vector< EqualizePattern > &  patterns 
)

It checks if given patterns are valid as follows.

  • "scale" is empty.
  • "front" and "back" of the patterns are in the graph.

Definition at line 37 of file EqualizePatternCheck.cpp.

38{
39 // Create a map to find node by its name
40 std::map<std::string, const luci::CircleNode *> node_by_name;
41 {
42 for (auto node : loco::active_nodes(loco::output_nodes(g)))
43 {
44 auto cnode = loco::must_cast<luci::CircleNode *>(node);
45 node_by_name[cnode->name()] = cnode;
46 }
47 }
48
49 for (const auto &p : patterns)
50 {
51 // "scale" is empty.
52 // "scale" is calculated in the runtime.
53 if (not p.scale.empty())
54 {
55 throw std::runtime_error{"'scale' shouldn't exist."};
56 }
57
58 // "front" and "back" of the patterns are in the graph.
59 if (node_by_name.find(p.front) == node_by_name.end() or
60 node_by_name.find(p.back) == node_by_name.end())
61 {
62 throw std::runtime_error{"Given front or back don't exist in the graph."};
63 }
64 }
65}
Configuration p

References loco::active_nodes(), loco::output_nodes(), and p.

Referenced by fme_apply::FMEqualizer::equalize().

◆ copy_shape()

void fme_apply::copy_shape ( luci::CircleNode * from,
luci::CircleNode * to
)

Definition at line 22 of file Support.Misc.cpp.

23{
24 if (not from)
25 throw std::invalid_argument("from");
26
27 if (not to)
28 throw std::invalid_argument("to");
29
30 to->rank(from->rank());
31 for (uint32_t i = 0; i < from->rank(); ++i)
32 {
33 to->dim(i) = from->dim(i);
34 }
35}

◆ find_arg_with_name()

luci::CircleNode * fme_apply::find_arg_with_name ( const luci::CircleNode * node,
const std::string &  name,
const uint32_t &  depth 
)

It returns one of given node's arguments whose name is "name".

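The search is recursive: arguments of arguments are also examined, up to "depth" levels back from the given node. A depth of 0 searches nothing and returns nullptr, as does a search that finds no match.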

Definition at line 174 of file Support.Misc.cpp.

176{
177 if (depth == 0)
178 return nullptr;
179
180 const auto arity = node->arity();
181 for (uint32_t idx = 0; idx < arity; idx++)
182 {
183 auto front_node = loco::must_cast<luci::CircleNode *>(node->arg(idx));
184 if (front_node->name() == name)
185 return front_node;
186 front_node = find_arg_with_name(front_node, name, depth - 1);
187 if (front_node)
188 return front_node;
189 }
190 return nullptr;
191}
virtual Node * arg(uint32_t N) const =0
Access N-th argument node.
virtual uint32_t arity(void) const =0
Return the number of arguments.
luci::CircleNode * find_arg_with_name(const luci::CircleNode *node, const std::string &name, const uint32_t &depth)

References loco::Node::arg(), loco::Node::arity(), and find_arg_with_name().

Referenced by find_arg_with_name().

◆ get_input()

loco::Node * fme_apply::get_input ( luci::CircleNode * node)

It returns given node's input.

Definition at line 40 of file Support.Misc.cpp.

41{
42 switch (node->opcode())
43 {
44 case luci::CircleOpcode::CONV_2D:
45 {
46 auto conv = loco::must_cast<luci::CircleConv2D *>(node);
47 return conv->input();
48 }
49 case luci::CircleOpcode::DEPTHWISE_CONV_2D:
50 {
51 auto dconv = loco::must_cast<luci::CircleDepthwiseConv2D *>(node);
52 return dconv->input();
53 }
54 case luci::CircleOpcode::FULLY_CONNECTED:
55 {
56 auto fc = loco::must_cast<luci::CircleFullyConnected *>(node);
57 return fc->input();
58 }
59 case luci::CircleOpcode::GELU:
60 {
61 auto gelu = loco::must_cast<luci::CircleGelu *>(node);
62 return gelu->features();
63 }
64 case luci::CircleOpcode::LEAKY_RELU:
65 {
66 auto relu = loco::must_cast<luci::CircleLeakyRelu *>(node);
67 return relu->features();
68 }
69 case luci::CircleOpcode::MAX_POOL_2D:
70 {
71 auto maxpool = loco::must_cast<luci::CircleMaxPool2D *>(node);
72 return maxpool->value();
73 }
74 case luci::CircleOpcode::PAD:
75 {
76 auto pad = loco::must_cast<luci::CirclePad *>(node);
77 return pad->input();
78 }
79 case luci::CircleOpcode::RELU:
80 {
81 auto relu = loco::must_cast<luci::CircleLeakyRelu *>(node);
82 return relu->features();
83 }
84 case luci::CircleOpcode::TRANSPOSE_CONV:
85 {
86 auto tconv = loco::must_cast<luci::CircleTransposeConv *>(node);
87 return tconv->outBackprop();
88 }
89 default:
90 {
91 throw std::runtime_error("(get_input) NYI operator: " + node->name());
92 }
93 }
94}
NodeName name(void) const
virtual CircleOpcode opcode(void) const =0

References luci::CircleNode::name(), and luci::CircleNode::opcode().

◆ random_str()

std::string fme_apply::random_str ( uint32_t  len)

Definition at line 25 of file RandomString.cpp.

26{
27 static const char cand[] = "0123456789"
28 "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
29 "abcdefghijklmnopqrstuvwxyz";
30
31 std::string res;
32 res.reserve(len);
33
34 for (uint32_t i = 0; i < len; ++i)
35 {
36 res += cand[std::rand() % (sizeof(cand) - 1)];
37 }
38
39 return res;
40}

◆ read()

std::vector< EqualizePattern > fme_apply::read ( const std::string &  filename)

Definition at line 44 of file EqualizePatternRead.cpp.

45{
46 Json::Value root;
47 std::ifstream ifs(filename);
48
49 // Failed to open cfg file
50 if (not ifs.is_open())
51 throw std::runtime_error("Cannot open config file. " + filename);
52
53 Json::CharReaderBuilder builder;
54 JSONCPP_STRING errs;
55
56 // Failed to parse
57 if (not parseFromStream(builder, ifs, &root, &errs))
58 throw std::runtime_error("Cannot parse config file (json format). " + errs);
59
60 std::vector<EqualizePattern> res;
61
62 for (auto &eq_pattern : root)
63 {
64 auto get_string = [&](const std::string &val) {
65 if (not eq_pattern.isMember(val))
66 throw std::runtime_error(val + " is missing in " + filename);
67 if (not eq_pattern[val].isString())
68 throw std::runtime_error(val + " is not string");
69
70 return eq_pattern[val].asString();
71 };
72
73 auto get_fp32_array = [&](const std::string &val) {
74 if (not eq_pattern.isMember(val))
75 throw std::runtime_error(val + " is missing in " + filename);
76 auto arr = eq_pattern[val];
77 if (not arr.isArray())
78 throw std::runtime_error(val + " is not array");
79
80 std::vector<float> res;
81 for (auto &elem : arr)
82 {
83 if (not elem.isNumeric())
84 throw std::runtime_error(val + "'s element is not fp32");
85
86 res.emplace_back(elem.asFloat());
87 }
88
89 return res;
90 };
91
92 auto front = get_string("front");
93 auto back = get_string("back");
94 auto type = get_string("type");
95
96 EqualizePattern p;
97 {
98 p.front = front;
99 p.back = back;
100 p.type = eq_type(type);
101 switch (p.type)
102 {
103 case EqualizePattern::Type::ScaleOnly:
104 p.act_scale = get_fp32_array("act_scale");
105 break;
106 default:
107 throw std::runtime_error("Unsupported EqualizePattern type");
108 }
109 }
110 res.emplace_back(p);
111 }
112
113 return res;
114}
type
Definition infer.py:18
std::string get_string(void)
get_string will return string of major.minor.patch (without build)
Definition version.cpp:44

References fme_apply::EqualizePattern::front, p, and fme_apply::EqualizePattern::ScaleOnly.

Referenced by entry().

◆ set_input()

void fme_apply::set_input ( luci::CircleNode * node,
luci::CircleCustom * input
)

It sets given 'input' to node's input.

Definition at line 99 of file Support.Misc.cpp.

100{
101 if (input == nullptr)
102 {
103 throw std::runtime_error("Invalid input.");
104 }
105
106 switch (node->opcode())
107 {
108 case luci::CircleOpcode::CONV_2D:
109 {
110 auto conv = loco::must_cast<luci::CircleConv2D *>(node);
111 conv->input(input);
112 break;
113 }
114 case luci::CircleOpcode::DEPTHWISE_CONV_2D:
115 {
116 auto dconv = loco::must_cast<luci::CircleDepthwiseConv2D *>(node);
117 dconv->input(input);
118 break;
119 }
120 case luci::CircleOpcode::FULLY_CONNECTED:
121 {
122 auto fc = loco::must_cast<luci::CircleFullyConnected *>(node);
123 fc->input(input);
124 break;
125 }
126 case luci::CircleOpcode::GELU:
127 {
128 auto gelu = loco::must_cast<luci::CircleGelu *>(node);
129 gelu->features(input);
130 break;
131 }
132 case luci::CircleOpcode::LEAKY_RELU:
133 {
134 auto relu = loco::must_cast<luci::CircleLeakyRelu *>(node);
135 relu->features(input);
136 break;
137 }
138 case luci::CircleOpcode::MAX_POOL_2D:
139 {
140 auto maxpool = loco::must_cast<luci::CircleMaxPool2D *>(node);
141 maxpool->value(input);
142 break;
143 }
144 case luci::CircleOpcode::PAD:
145 {
146 auto pad = loco::must_cast<luci::CirclePad *>(node);
147 pad->input(input);
148 break;
149 }
150 case luci::CircleOpcode::RELU:
151 {
152 auto relu = loco::must_cast<luci::CircleLeakyRelu *>(node);
153 relu->features(input);
154 break;
155 }
156 case luci::CircleOpcode::TRANSPOSE_CONV:
157 {
158 auto tconv = loco::must_cast<luci::CircleTransposeConv *>(node);
159 tconv->outBackprop(input);
160 break;
161 }
162 default:
163 {
164 throw std::runtime_error("(set_input) NYI operator: " + node->name());
165 }
166 }
167}

References luci::CircleNode::name(), and luci::CircleNode::opcode().

◆ to_scale()

luci::CircleCustom * fme_apply::to_scale ( loco::Node * node)

Definition at line 22 of file Support.Cast.cpp.

23{
24 auto scale = dynamic_cast<luci::CircleCustom *>(node);
25 if (not scale)
26 return nullptr;
27
28 if (scale->custom_code() != "scale")
29 return nullptr;
30
31 // TODO Return false?
32 assert(scale->numInputs() == 2); // FIX_PreScale_UNLESS
33
34 return scale;
35}
CUSTOM in Circle.