19#include "schema_generated.h"
// Creates an importer for the TFLite model file at `filename`. The constructor
// only sets up the target graph and the operator creator; it does not read the file.
40 explicit TfliteImporter(std::string filename);
// Rejects unsupported operators, converts every subgraph of _model into MIR and
// returns the graph. _graph is moved out, so an importer is meant to be used once.
43 std::unique_ptr<mir::Graph> importModel();
// Path of the model file opened by import().
48 std::string _filename;
// Object-API form of the flatbuffer model, produced by tflite::UnPackModel in import().
49 std::unique_ptr<tflite::ModelT> _model;
// MIR graph being built; ownership passes to the caller of importModel().
51 std::unique_ptr<mir::Graph> _graph;
// Helper that creates MIR operations inside _graph for individual TFLite operators.
52 std::unique_ptr<TFLiteOpCreator> _opCreator;
// Maps a tensor index of the subgraph being converted to the MIR output that
// produces it; entries stay nullptr until that producer has been converted.
55 std::vector<mir::Operation::Output *> _tensorMap;
// Converts every subgraph of `model`.
59 void walkModel(
const tflite::ModelT *model);
// Converts one subgraph: its inputs, its operators (in order) and its outputs.
61 void walkSubgraph(
const tflite::SubGraphT *subgraph);
// Converts a single operator of `subgraph` and registers its outputs in _tensorMap.
63 void walkOperator(
const tflite::SubGraphT *subgraph,
const tflite::OperatorT *op);
// Scans all operators of _model and throws std::runtime_error listing every
// unsupported one, so conversion never starts on a model it cannot finish.
69 void collectUnsupportedOps();
// Returns the MIR outputs feeding the inputs of `op`, in the operator's operand order.
74 std::vector<mir::Operation::Output *> getMIRInputsForOperator(
const tflite::SubGraphT *subgraph,
75 const tflite::OperatorT *op);
78TfliteImporter::TfliteImporter(std::string filename) : _filename(
std::move(filename))
80 _graph = std::make_unique<mir::Graph>();
81 _opCreator = std::make_unique<TFLiteOpCreator>(_graph.get());
84TfliteImporter::~TfliteImporter() =
default;
86void TfliteImporter::import()
88 std::ifstream stream(_filename, std::ios::in | std::ios::binary);
90 throw std::runtime_error(
"Couldn't open file \"" + _filename +
"\".");
92 std::vector<char> model_buffer((std::istreambuf_iterator<char>(stream)),
93 std::istreambuf_iterator<char>());
96 throw std::runtime_error(
"Couldn't read file \"" + _filename +
"\".");
98 flatbuffers::Verifier verifier(
reinterpret_cast<const std::uint8_t *
>(model_buffer.data()),
101 if (!tflite::VerifyModelBuffer(verifier))
102 throw std::runtime_error(
"Could not load model: " + _filename +
"\n");
104 _model = tflite::UnPackModel(model_buffer.data());
107static const std::set<tflite::BuiltinOperator> supportedOperators = {
108 tflite::BuiltinOperator_ADD,
109 tflite::BuiltinOperator_AVERAGE_POOL_2D,
110 tflite::BuiltinOperator_CONCATENATION,
111 tflite::BuiltinOperator_CONV_2D,
112 tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
113 tflite::BuiltinOperator_DIV,
114 tflite::BuiltinOperator_FULLY_CONNECTED,
115 tflite::BuiltinOperator_HARD_SWISH,
116 tflite::BuiltinOperator_LEAKY_RELU,
117 tflite::BuiltinOperator_LOGISTIC,
118 tflite::BuiltinOperator_MAX_POOL_2D,
119 tflite::BuiltinOperator_MAXIMUM,
120 tflite::BuiltinOperator_MEAN,
121 tflite::BuiltinOperator_MUL,
122 tflite::BuiltinOperator_PAD,
123 tflite::BuiltinOperator_RELU,
124 tflite::BuiltinOperator_RELU6,
125 tflite::BuiltinOperator_RESHAPE,
126 tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
127 tflite::BuiltinOperator_RSQRT,
128 tflite::BuiltinOperator_SHAPE,
129 tflite::BuiltinOperator_SLICE,
130 tflite::BuiltinOperator_SOFTMAX,
131 tflite::BuiltinOperator_SQRT,
132 tflite::BuiltinOperator_SQUARED_DIFFERENCE,
133 tflite::BuiltinOperator_SQUEEZE,
134 tflite::BuiltinOperator_STRIDED_SLICE,
135 tflite::BuiltinOperator_SUB,
136 tflite::BuiltinOperator_TANH,
137 tflite::BuiltinOperator_TRANSPOSE,
138 tflite::BuiltinOperator_TRANSPOSE_CONV,
141void TfliteImporter::collectUnsupportedOps()
143 std::set<std::string> errors;
144 for (
const auto &subgraph : _model->subgraphs)
145 for (const auto &op : subgraph->operators)
147 tflite::BuiltinOperator opcode = _model->operator_codes[op->opcode_index]->builtin_code;
148 if (supportedOperators.find(opcode) == supportedOperators.end())
150 if (opcode <= tflite::BuiltinOperator_MAX)
151 errors.insert(std::string(EnumNameBuiltinOperator(opcode)) +
": unsupported operator");
153 errors.insert(std::to_string(opcode) +
": unsuppored in tflite custom opcode");
159 std::string msg(
"NNC can't load model. Detected problems:");
160 for (
const auto &e : errors)
161 msg.append(
"\n * " + e);
162 throw std::runtime_error(msg);
166std::unique_ptr<mir::Graph> TfliteImporter::importModel()
169 collectUnsupportedOps();
170 walkModel(_model.get());
171 return std::move(_graph);
174void TfliteImporter::walkModel(
const tflite::ModelT *model)
176 for (
const auto &subgraph :
model->subgraphs)
177 walkSubgraph(subgraph.
get());
// Maps a TFLite tensor element type to the matching MIR data type. Only INT32,
// FLOAT32, INT64 and UINT8 are accepted; any other type ends in a
// std::runtime_error that names the offending type.
// NOTE(review): the function header and the enclosing `switch` / `default:`
// labels of these cases are not visible here. Confirm against the full source.
184 case tflite::TensorType_INT32:
185 return mir::DataType::INT32;
186 case tflite::TensorType_FLOAT32:
187 return mir::DataType::FLOAT32;
188 case tflite::TensorType_INT64:
189 return mir::DataType::INT64;
190 case tflite::TensorType_UINT8:
191 return mir::DataType::UINT8;
// Presumably the fall-through for every type not listed above.
193 throw std::runtime_error(std::string(
"Unsupported tensor type: ") + EnumNameTensorType(type));
// Builds the MIR tensor type for a TFLite tensor: copies each dimension of
// `tensor.shape` and validates the tensor's optional quantization parameters.
// Copy the dimensions one by one into the MIR shape.
202 for (std::size_t i = 0; i <
tensor.shape.size(); ++i)
204 shape.dim(i) =
tensor.shape[i];
// Quantization parameters are optional; a null pointer means none were given.
207 if (
tensor.quantization !=
nullptr)
// NOTE(review): `¶ms` looks like a mis-encoded `&params` (an HTML `&para`
// entity artefact); it must read `&params` for this declaration to compile.
209 const tflite::QuantizationParametersT ¶ms = *
tensor.quantization;
// Only the plain scale/zero-point scheme is accepted, not custom quantization details.
211 if (params.details.type != tflite::QuantizationDetails_NONE)
212 throw std::runtime_error(
"Custom quantization is not supported.");
// Neither scale nor zero point given: presumably treated as "not quantized"
// (the statement taken in this case is not visible here).
215 if (params.scale.empty() && params.zero_point.empty())
// Only per-tensor quantization (exactly one scale and one zero point) is accepted.
218 if (params.scale.size() != 1 || params.zero_point.size() != 1)
219 throw std::runtime_error(
"Non-scalar quantization is not supported.");
// Converts one subgraph into MIR. It first registers the subgraph's input
// tensors, then converts its operators in file order, then visits the output
// tensors.
231void TfliteImporter::walkSubgraph(
const tflite::SubGraphT *subgraph)
// Tensor indices refer to this subgraph's tensor list, so the map is reset to
// one null slot per tensor.
233 _tensorMap.assign(subgraph->tensors.size(),
nullptr);
// Each subgraph input becomes a MIR value recorded under its tensor index.
235 for (
const auto input_tensor_index : subgraph->
inputs)
237 const tflite::TensorT &
tensor = *subgraph->tensors[input_tensor_index];
// NOTE(review): the statements that create `input` from `tensor` are not
// visible here. Confirm against the full source.
// An input index must be registered only once.
243 assert(_tensorMap[input_tensor_index] ==
nullptr);
244 _tensorMap[input_tensor_index] =
input;
// Operators are converted in file order. A non-constant operand must already be
// in _tensorMap when it is consumed (getMIRInputsForOperator asserts this).
247 for (
const auto &op : subgraph->operators)
249 walkOperator(subgraph, op.get());
// Look up the MIR value that produces each subgraph output.
252 for (
const auto output_tensor_index : subgraph->outputs)
254 auto output = _tensorMap[output_tensor_index];
// Converts one TFLite operator into MIR. It gathers the MIR values feeding the
// operator's inputs, dispatches on the builtin opcode to the matching
// TFLiteOpCreator method, then names, types and registers each produced output
// in _tensorMap.
259void TfliteImporter::walkOperator(
const tflite::SubGraphT *subgraph,
const tflite::OperatorT *op)
261 std::vector<mir::Operation::Output *>
inputs = getMIRInputsForOperator(subgraph, op);
262 std::vector<mir::Operation::Output *> outputs;
// The opcode index is not range-checked here. importModel() runs
// collectUnsupportedOps() over the whole model before any operator is walked.
264 tflite::BuiltinOperator opcode = _model->operator_codes[op->opcode_index]->builtin_code;
// One case per entry of supportedOperators. Options are read from the
// operator's builtin_options union where the converter needs them.
// NOTE(review): the `switch` header, the `break`s and the first line of some
// split statements (e.g. before convertDepthwiseConv2D) are not visible here.
267 case tflite::BuiltinOperator_CONV_2D:
268 outputs = _opCreator->convertConv2D(op->builtin_options.AsConv2DOptions(), inputs);
270 case tflite::BuiltinOperator_DEPTHWISE_CONV_2D:
272 _opCreator->convertDepthwiseConv2D(op->builtin_options.AsDepthwiseConv2DOptions(), inputs);
274 case tflite::BuiltinOperator_MAX_POOL_2D:
275 outputs = _opCreator->convertMaxPool2D(op->builtin_options.AsPool2DOptions(), inputs);
277 case tflite::BuiltinOperator_AVERAGE_POOL_2D:
278 outputs = _opCreator->convertAveragePool2D(op->builtin_options.AsPool2DOptions(), inputs);
280 case tflite::BuiltinOperator_CONCATENATION:
282 _opCreator->convertConcatenation(op->builtin_options.AsConcatenationOptions(), inputs);
284 case tflite::BuiltinOperator_RESHAPE:
285 outputs = _opCreator->convertReshape(op->builtin_options.AsReshapeOptions(), inputs);
287 case tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
288 outputs = _opCreator->convertResizeNearestNeighbor(
289 op->builtin_options.AsResizeNearestNeighborOptions(), inputs);
291 case tflite::BuiltinOperator_MEAN:
292 outputs = _opCreator->convertMean(op->builtin_options.AsReducerOptions(), inputs);
294 case tflite::BuiltinOperator_FULLY_CONNECTED:
296 _opCreator->convertFullyConnected(op->builtin_options.AsFullyConnectedOptions(), inputs);
298 case tflite::BuiltinOperator_SOFTMAX:
299 outputs = _opCreator->convertSoftmax(op->builtin_options.AsSoftmaxOptions(), inputs);
301 case tflite::BuiltinOperator_SLICE:
302 outputs = _opCreator->convertSlice(op->builtin_options.AsSliceOptions(), inputs);
304 case tflite::BuiltinOperator_SQUEEZE:
305 outputs = _opCreator->convertSqueeze(op->builtin_options.AsSqueezeOptions(), inputs);
307 case tflite::BuiltinOperator_LOGISTIC:
308 outputs = _opCreator->convertLogistic(inputs);
310 case tflite::BuiltinOperator_RSQRT:
311 outputs = _opCreator->convertRsqrt(inputs);
313 case tflite::BuiltinOperator_SQRT:
314 outputs = _opCreator->convertSqrt(inputs);
316 case tflite::BuiltinOperator_ADD:
317 outputs = _opCreator->convertAdd(op->builtin_options.AsAddOptions(), inputs);
319 case tflite::BuiltinOperator_SUB:
320 outputs = _opCreator->convertSub(op->builtin_options.AsSubOptions(), inputs);
322 case tflite::BuiltinOperator_MUL:
323 outputs = _opCreator->convertMul(op->builtin_options.AsMulOptions(), inputs);
325 case tflite::BuiltinOperator_DIV:
326 outputs = _opCreator->convertDiv(op->builtin_options.AsDivOptions(), inputs);
328 case tflite::BuiltinOperator_MAXIMUM:
329 outputs = _opCreator->convertMax(inputs);
331 case tflite::BuiltinOperator_SQUARED_DIFFERENCE:
332 outputs = _opCreator->convertSquaredDifference(inputs);
334 case tflite::BuiltinOperator_TRANSPOSE_CONV:
336 _opCreator->convertTransposeConv(op->builtin_options.AsTransposeConvOptions(), inputs);
338 case tflite::BuiltinOperator_PAD:
339 outputs = _opCreator->convertPad(op->builtin_options.AsPadOptions(), inputs);
341 case tflite::BuiltinOperator_TANH:
342 outputs = _opCreator->convertTanh(inputs);
344 case tflite::BuiltinOperator_RELU:
345 outputs = _opCreator->convertReLU(inputs);
347 case tflite::BuiltinOperator_RELU6:
348 outputs = _opCreator->convertReLU6(inputs);
350 case tflite::BuiltinOperator_TRANSPOSE:
351 outputs = _opCreator->convertTranspose(op->builtin_options.AsTransposeOptions(), inputs);
353 case tflite::BuiltinOperator_STRIDED_SLICE:
355 _opCreator->convertStridedSlice(op->builtin_options.AsStridedSliceOptions(), inputs);
357 case tflite::BuiltinOperator_LEAKY_RELU:
358 outputs = _opCreator->convertLeakyReLU(op->builtin_options.AsLeakyReluOptions(), inputs);
360 case tflite::BuiltinOperator_SHAPE:
361 outputs = _opCreator->convertShape(op->builtin_options.AsShapeOptions(), inputs);
363 case tflite::BuiltinOperator_HARD_SWISH:
364 outputs = _opCreator->convertHardSwish(op->builtin_options.AsHardSwishOptions(), inputs);
// Unreachable for models that passed collectUnsupportedOps().
367 assert(
false &&
"All unsupported types should have been found before this pass.");
// The converter must yield exactly one MIR output per TFLite output tensor.
370 assert(outputs.size() == op->outputs.size());
371 for (std::size_t i = 0; i < op->outputs.size(); ++i)
373 const auto tensor_index = op->outputs[i];
374 const tflite::TensorT &
tensor = *subgraph->tensors[tensor_index];
// NOTE(review): the computation of `output_type` and the opening line of the
// shape assertion continued below are not visible here.
380 outputs[i]->getType().getShape() ==
output_type.getShape());
// Carry the TFLite tensor's name and type over to the MIR output.
382 outputs[i]->setName(
tensor.name);
383 outputs[i]->setType(output_type);
// Each tensor has a single producer, so its slot must still be empty.
385 assert(_tensorMap[tensor_index] ==
nullptr);
386 _tensorMap[tensor_index] = outputs[i];
// Returns the MIR values feeding `op`'s input tensors, in operand order.
// NOTE(review): the TFLite schema marks omitted optional inputs with index -1.
// No check for that is visible here before `subgraph->tensors[tensor_index]`.
390std::vector<mir::Operation::Output *>
391TfliteImporter::getMIRInputsForOperator(
const tflite::SubGraphT *subgraph,
392 const tflite::OperatorT *op)
394 std::vector<mir::Operation::Output *>
inputs;
396 for (
const auto tensor_index : op->
inputs)
398 const tflite::TensorT &
tensor = *subgraph->tensors[tensor_index];
// A tensor whose buffer carries data is presumably a constant (e.g. weights).
// The statements that handle this case are not visible here.
399 const tflite::BufferT &buffer = *_model->buffers[
tensor.buffer];
400 if (!buffer.data.empty())
// A data-backed tensor must not also be a subgraph input or an operator output.
402 assert(_tensorMap[tensor_index] ==
nullptr);
// For tensors without buffer data, the value must already exist as a subgraph
// input or an earlier operator's output (branch structure not fully visible here).
409 assert(_tensorMap[tensor_index] !=
nullptr);
413 inputs.emplace_back(_tensorMap[tensor_index]);
422std::unique_ptr<mir::Graph> loadModel(std::string filename)
424 TfliteImporter importer(std::move(filename));
425 return importer.importModel();
Output * getOutput(std::size_t index)
KnobTrait< K >::ValueType get(void)
constexpr DataType getElementType()
NNFW_TYPE getType(const char *type="")