20#include <luci_interpreter/core/Tensor.h>
24#include <pybind11/numpy.h>
30namespace py = pybind11;
31using namespace py::literals;
// THROW_UNLESS(COND, MSG): assertion-style helper used below to reject invalid
// state by throwing std::runtime_error(MSG).
// NOTE(review): the macro line that evaluates COND is not visible in this listing;
// presumably the throw is guarded by a negated test of COND — confirm. If it is a
// bare `if (...) throw ...;`, wrapping it in `do { ... } while (0)` would avoid
// dangling-else surprises at call sites.
#define THROW_UNLESS(COND, MSG) \
35 throw std::runtime_error(MSG);
// numpyArray: converts a luci_interpreter Tensor into a C-contiguous (c_style) NumPy array.
//
// @param tensor  tensor to convert; must be non-null (asserted).
// @return        py::array whose dtype follows tensor->element_type():
//                FLOAT32->float, S16->int16_t, S32->int32_t, S64->int64_t,
//                U8->uint8_t, BOOL->bool; shape follows tensor->shape().
// @throws std::runtime_error  via THROW_UNLESS if any dimension is negative, or
//                "Unsupported data type" for any other element type.
//
// NOTE(review): py::array_t(shape, ptr) is called without a `base` handle. In pybind11
// that makes the array copy the data from ptr, so the result should not alias tensor
// memory. Confirm this against the pybind11 version in use.
40py::array numpyArray(
const Tensor *tensor)
42 assert(tensor !=
nullptr);
44 const auto tensor_shape = tensor->shape();
// Copy the tensor's dimensions into an unsigned shape vector for pybind11,
// rejecting negative (e.g. unknown) dimensions first.
47 std::vector<uint32_t> shape(tensor_shape.num_dims());
48 for (
int i = 0; i < tensor_shape.num_dims(); i++)
50 THROW_UNLESS(tensor_shape.
dim(i) >= 0,
"Negative dimension detected in " + tensor->name());
52 shape[i] = tensor_shape.
dim(i);
// Dispatch on the element type so the NumPy dtype matches the buffer layout.
59 switch (tensor->element_type())
61 case loco::DataType::FLOAT32:
62 return py::array_t<float, py::array::c_style>(shape, tensor->data<
float>());
63 case loco::DataType::S16:
64 return py::array_t<int16_t, py::array::c_style>(shape, tensor->data<int16_t>());
65 case loco::DataType::S32:
66 return py::array_t<int32_t, py::array::c_style>(shape, tensor->data<int32_t>());
67 case loco::DataType::S64:
68 return py::array_t<int64_t, py::array::c_style>(shape, tensor->data<int64_t>());
69 case loco::DataType::U8:
70 return py::array_t<uint8_t, py::array::c_style>(shape, tensor->data<uint8_t>());
71 case loco::DataType::BOOL:
72 return py::array_t<bool, py::array::c_style>(shape, tensor->data<
bool>());
// Any element type not handled above is rejected.
74 throw std::runtime_error(
"Unsupported data type");
// quantparam: builds a Python dict describing a tensor's quantization parameters.
//
// @param tensor  tensor to describe; must be non-null (asserted).
// @return        py::dict with keys:
//                  "scale"               -> py_scale, derived from tensor->scales()
//                  "zero_point"          -> py_zp, derived from tensor->zero_points()
//                  "quantized_dimension" -> tensor->quantized_dimension()
// NOTE(review): the code that turns `scale`/`zp` into py_scale/py_zp is not visible
// here. Presumably it produces Python sequences; confirm.
78py::dict quantparam(
const Tensor *tensor)
80 assert(tensor !=
nullptr);
82 auto scale = tensor->scales();
83 auto zp = tensor->zero_points();
97 auto quantparam = py::dict(
"scale"_a = py_scale,
"zero_point"_a = py_zp,
98 "quantized_dimension"_a = tensor->quantized_dimension());
107py::object
none() {
return py::none(); }
// inputsPyArray body: collects one py::dict per argument of `node`, with keys
// "name", "data" (numpyArray of the argument's tensor), "quantparam" and
// "is_const" (true iff the argument's opcode is CIRCLECONST).
// Both `node` and `interpreter` must be non-null (asserted).
112 assert(node !=
nullptr);
113 assert(interpreter !=
nullptr);
115 std::vector<py::dict> inputs;
116 for (uint32_t i = 0; i < node->
arity(); ++i)
118 const auto input_tensor = interpreter->getTensor(node->
arg(i));
// Arguments with opcode CIRCLEOUTPUTEXCLUDE get special handling. The branch body is
// not visible here; presumably such placeholder/absent inputs are skipped — confirm.
// NOTE(review): `circle_node` is defined outside this view; presumably it is
// node->arg(i) viewed as a luci::CircleNode.
122 if (circle_node->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
126 py::dict(
"name"_a = circle_node->name(),
"data"_a = numpyArray(input_tensor),
127 "quantparam"_a = quantparam(input_tensor),
128 "is_const"_a = circle_node->opcode() == luci::CircleOpcode::CIRCLECONST);
129 inputs.push_back(py_input);
// outputsPyArray body: builds one py::dict per successor `succ` of `node`, with keys
// "name", "data", "quantparam" and "is_const". The same keys are used for inputs.
// The loop header over successors is not visible here.
137 std::vector<py::dict> outputs;
140 const auto output_tensor = interpreter->getTensor(succ);
143 auto opcode_str =
toString(circle_node->opcode());
// Every successor's opcode name must end in "Out"; otherwise throw
// "Invalid output detected in <node name>".
// NOTE(review): if opcode_str is shorter than 3 characters, length() - 3 wraps around
// and substr throws std::out_of_range instead of the intended runtime_error. Consider
// checking the length first.
147 THROW_UNLESS(opcode_str.substr(opcode_str.length() - 3) ==
"Out",
148 "Invalid output detected in " + node->
name());
151 py::dict(
"name"_a = circle_node->name(),
"data"_a = numpyArray(output_tensor),
152 "quantparam"_a = quantparam(output_tensor),
153 "is_const"_a = circle_node->opcode() == luci::CircleOpcode::CIRCLECONST);
154 outputs.push_back(py_output);
// outputPyArray body: describes `node`'s own tensor as a py::dict with keys
// "name" (node->name()), "data" (numpyArray), "quantparam" and "is_const"
// (true iff node's opcode is CIRCLECONST).
// Both `node` and `interpreter` must be non-null (asserted).
162 assert(node !=
nullptr);
163 assert(interpreter !=
nullptr);
165 const auto tensor = interpreter->getTensor(node);
169 auto py_output = py::dict(
"name"_a = node->
name(),
"data"_a = numpyArray(tensor),
170 "quantparam"_a = quantparam(tensor),
171 "is_const"_a = node->
opcode() == luci::CircleOpcode::CIRCLECONST);
// Opcodes grouped here are the ones multi_out_node singles out. The switch head and the
// statement these cases fall through to are not visible. Presumably this group returns
// true (the node produces several outputs) — confirm.
181 case luci::CircleOpcode::BIDIRECTIONAL_SEQUENCE_LSTM:
182 case luci::CircleOpcode::CUSTOM:
183 case luci::CircleOpcode::IF:
184 case luci::CircleOpcode::NON_MAX_SUPPRESSION_V4:
185 case luci::CircleOpcode::NON_MAX_SUPPRESSION_V5:
186 case luci::CircleOpcode::SPLIT:
187 case luci::CircleOpcode::SPLIT_V:
188 case luci::CircleOpcode::TOPK_V2:
189 case luci::CircleOpcode::UNIQUE:
190 case luci::CircleOpcode::UNPACK:
#define THROW_UNLESS(COND, MSG)
virtual Node * arg(uint32_t N) const =0
Access N-th argument node.
virtual uint32_t arity(void) const =0
Return the number of arguments.
const Dimension & dim(uint32_t axis) const
#define THROW_UNLESS(COND, MSG)
std::vector< py::dict > inputsPyArray(const luci::CircleNode *node, luci_interpreter::Interpreter *interpreter)
const std::string toString(luci::CircleOpcode opcode)
std::vector< py::dict > outputsPyArray(const luci::CircleNode *node, luci_interpreter::Interpreter *interpreter)
py::dict outputPyArray(const luci::CircleNode *node, luci_interpreter::Interpreter *interpreter)
bool multi_out_node(const luci::CircleNode *node)
std::set< Node * > succs(const Node *node)
Enumerate all the successors of a given node.
This file contains a utility macro.
NodeName name(void) const
virtual CircleOpcode opcode(void) const =0