37{
38
39py::array numpyArray(
const Tensor *tensor)
40{
41 assert(tensor != nullptr);
42
43 const auto tensor_shape =
tensor->shape();
44
46 std::vector<uint32_t> shape(tensor_shape.num_dims());
47 for (int i = 0; i < tensor_shape.num_dims(); i++)
48 {
50
51 shape[i] = tensor_shape.
dim(i);
53 }
54
56 return py::none();
57
58 switch (
tensor->element_type())
59 {
60 case loco::DataType::FLOAT32:
61 return py::array_t<float, py::array::c_style>(shape,
tensor->data<
float>());
62 case loco::DataType::S16:
63 return py::array_t<int16_t, py::array::c_style>(shape,
tensor->data<int16_t>());
64 case loco::DataType::S32:
65 return py::array_t<int32_t, py::array::c_style>(shape,
tensor->data<int32_t>());
66 case loco::DataType::S64:
67 return py::array_t<int64_t, py::array::c_style>(shape,
tensor->data<int64_t>());
68 case loco::DataType::U8:
69 return py::array_t<uint8_t, py::array::c_style>(shape,
tensor->data<uint8_t>());
70 case loco::DataType::BOOL:
71 return py::array_t<bool, py::array::c_style>(shape,
tensor->data<
bool>());
72 default:
73 throw std::runtime_error("Unsupported data type");
74 }
75}
76
77py::dict quantparam(
const Tensor *tensor)
78{
79 assert(tensor != nullptr);
80
82 auto zp =
tensor->zero_points();
83
84 py::list py_scale;
86 {
87 py_scale.append(s);
88 }
89
90 py::list py_zp;
91 for (auto z : zp)
92 {
93 py_zp.append(z);
94 }
95
96 auto quantparam = py::dict("scale"_a = py_scale, "zero_point"_a = py_zp,
97 "quantized_dimension"_a =
tensor->quantized_dimension());
98 return quantparam;
99}
100
101}
102
104{
105
106py::object
none() {
return py::none(); }
107
110{
111 assert(node != nullptr);
112 assert(interpreter != nullptr);
113
114 std::vector<py::dict>
inputs;
115 for (uint32_t i = 0; i < node->
arity(); ++i)
116 {
119
120
121 if (circle_node->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
122 continue;
123
124 auto py_input =
125 py::dict("name"_a = circle_node->name(), "data"_a = numpyArray(input_tensor),
126 "quantparam"_a = quantparam(input_tensor),
127 "is_const"_a = circle_node->opcode() == luci::CircleOpcode::CIRCLECONST);
128 inputs.push_back(py_input);
129 }
131}
132
135{
136 std::vector<py::dict> outputs;
138 {
141
142 auto opcode_str =
toString(circle_node->opcode());
143
144
145
146 THROW_UNLESS(opcode_str.substr(opcode_str.length() - 3) ==
"Out",
147 "Invalid output detected in " + node->
name());
148
149 auto py_output =
150 py::dict("name"_a = circle_node->name(), "data"_a = numpyArray(output_tensor),
151 "quantparam"_a = quantparam(output_tensor),
152 "is_const"_a = circle_node->opcode() == luci::CircleOpcode::CIRCLECONST);
153 outputs.push_back(py_output);
154 }
155 return outputs;
156}
157
158
160{
161 assert(node != nullptr);
162 assert(interpreter != nullptr);
163
165
167
168 auto py_output = py::dict(
"name"_a = node->
name(),
"data"_a = numpyArray(tensor),
169 "quantparam"_a = quantparam(tensor),
170 "is_const"_a = node->
opcode() == luci::CircleOpcode::CIRCLECONST);
171 return py_output;
172}
173
175{
177 {
178
179
180 case luci::CircleOpcode::BIDIRECTIONAL_SEQUENCE_LSTM:
181 case luci::CircleOpcode::CUSTOM:
182 case luci::CircleOpcode::IF:
183 case luci::CircleOpcode::NON_MAX_SUPPRESSION_V4:
184 case luci::CircleOpcode::NON_MAX_SUPPRESSION_V5:
185 case luci::CircleOpcode::SPLIT:
186 case luci::CircleOpcode::SPLIT_V:
187 case luci::CircleOpcode::TOPK_V2:
188 case luci::CircleOpcode::UNIQUE:
189 case luci::CircleOpcode::UNPACK:
190 return true;
191 default:
192 return false;
193 }
194}
195
196}
197
198#undef THROW_UNLESS
#define THROW_UNLESS(COND, MSG)
virtual Node * arg(uint32_t N) const =0
Access N-th argument node.
virtual uint32_t arity(void) const =0
Return the number of arguments.
const Dimension & dim(uint32_t axis) const
#define THROW_UNLESS(COND, MSG)
std::vector< py::dict > inputsPyArray(const luci::CircleNode *node, luci_interpreter::Interpreter *interpreter)
const std::string toString(luci::CircleOpcode opcode)
std::vector< py::dict > outputsPyArray(const luci::CircleNode *node, luci_interpreter::Interpreter *interpreter)
py::dict outputPyArray(const luci::CircleNode *node, luci_interpreter::Interpreter *interpreter)
bool multi_out_node(const luci::CircleNode *node)
std::set< Node * > succs(const Node *node)
Enumerate all the successors of a given node.
NodeName name(void) const
virtual CircleOpcode opcode(void) const =0