ONE - On-device Neural Engine
Loading...
Searching...
No Matches
tflite_importer.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "tflite_importer.h"
18#include "tflite_op_creator.h"
19#include "schema_generated.h"
20
21#include "mir/TensorVariant.h"
22#include "mir/ops/ConstantOp.h"
23#include "mir/ops/OutputOp.h"
24
25#include <fstream>
26#include <memory>
27#include <utility>
28#include <vector>
29#include <set>
30
31namespace mir_tflite
32{
33
34namespace
35{
36
// Converts a serialized TFLite model file into a MIR graph.
class TfliteImporter
{
public:
  explicit TfliteImporter(std::string filename);

  /// @brief Loads the model file and converts it to MIR.
  /// @return the constructed MIR graph (ownership transferred to the caller).
  /// @throws std::runtime_error on I/O failure, flatbuffer verification failure,
  ///         or when the model contains unsupported operators.
  std::unique_ptr<mir::Graph> importModel();

  ~TfliteImporter();

private:
  std::string _filename;                       // Path of the .tflite file to import.
  std::unique_ptr<tflite::ModelT> _model;      // Unpacked flatbuffer model, set by import().

  std::unique_ptr<mir::Graph> _graph;          // MIR graph under construction.
  std::unique_ptr<TFLiteOpCreator> _opCreator; // Translates individual TFLite ops to MIR ops.

  // Maps TFLite tensors indices to corresponding MIR operation outputs.
  std::vector<mir::Operation::Output *> _tensorMap;

  // Reads, verifies and unpacks the model file into _model.
  void import();

  void walkModel(const tflite::ModelT *model);

  void walkSubgraph(const tflite::SubGraphT *subgraph);

  void walkOperator(const tflite::SubGraphT *subgraph, const tflite::OperatorT *op);

  /// @brief Throws std::runtime_error listing every operator the importer cannot convert.
  void collectUnsupportedOps();

  /// @brief Returns MIR operation outputs corresponding to the inputs of the given operator.
  std::vector<mir::Operation::Output *> getMIRInputsForOperator(const tflite::SubGraphT *subgraph,
                                                                const tflite::OperatorT *op);
};
77
78TfliteImporter::TfliteImporter(std::string filename) : _filename(std::move(filename))
79{
80 _graph = std::make_unique<mir::Graph>();
81 _opCreator = std::make_unique<TFLiteOpCreator>(_graph.get());
82}
83
// Defaulted out-of-line; presumably kept in the .cpp so the unique_ptr member
// destructors are instantiated where the pointee types are complete — TODO confirm.
TfliteImporter::~TfliteImporter() = default;
85
86void TfliteImporter::import()
87{
88 std::ifstream stream(_filename, std::ios::in | std::ios::binary);
89 if (stream.fail())
90 throw std::runtime_error("Couldn't open file \"" + _filename + "\".");
91
92 std::vector<char> model_buffer((std::istreambuf_iterator<char>(stream)),
93 std::istreambuf_iterator<char>());
94
95 if (stream.fail())
96 throw std::runtime_error("Couldn't read file \"" + _filename + "\".");
97
98 flatbuffers::Verifier verifier(reinterpret_cast<const std::uint8_t *>(model_buffer.data()),
99 model_buffer.size());
100
101 if (!tflite::VerifyModelBuffer(verifier))
102 throw std::runtime_error("Could not load model: " + _filename + "\n");
103
104 _model = tflite::UnPackModel(model_buffer.data());
105}
106
// Builtin operators that TFLiteOpCreator can convert to MIR operations.
// collectUnsupportedOps() rejects models containing anything outside this set,
// so walkOperator()'s dispatch switch only ever sees these opcodes.
static const std::set<tflite::BuiltinOperator> supportedOperators = {
  tflite::BuiltinOperator_ADD,
  tflite::BuiltinOperator_AVERAGE_POOL_2D,
  tflite::BuiltinOperator_CONCATENATION,
  tflite::BuiltinOperator_CONV_2D,
  tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
  tflite::BuiltinOperator_DIV,
  tflite::BuiltinOperator_FULLY_CONNECTED,
  tflite::BuiltinOperator_HARD_SWISH,
  tflite::BuiltinOperator_LEAKY_RELU,
  tflite::BuiltinOperator_LOGISTIC,
  tflite::BuiltinOperator_MAX_POOL_2D,
  tflite::BuiltinOperator_MAXIMUM,
  tflite::BuiltinOperator_MEAN,
  tflite::BuiltinOperator_MUL,
  tflite::BuiltinOperator_PAD,
  tflite::BuiltinOperator_RELU,
  tflite::BuiltinOperator_RELU6,
  tflite::BuiltinOperator_RESHAPE,
  tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
  tflite::BuiltinOperator_RSQRT,
  tflite::BuiltinOperator_SHAPE,
  tflite::BuiltinOperator_SLICE,
  tflite::BuiltinOperator_SOFTMAX,
  tflite::BuiltinOperator_SQRT,
  tflite::BuiltinOperator_SQUARED_DIFFERENCE,
  tflite::BuiltinOperator_SQUEEZE,
  tflite::BuiltinOperator_STRIDED_SLICE,
  tflite::BuiltinOperator_SUB,
  tflite::BuiltinOperator_TANH,
  tflite::BuiltinOperator_TRANSPOSE,
  tflite::BuiltinOperator_TRANSPOSE_CONV,
};
140
141void TfliteImporter::collectUnsupportedOps()
142{
143 std::set<std::string> errors;
144 for (const auto &subgraph : _model->subgraphs)
145 for (const auto &op : subgraph->operators)
146 {
147 tflite::BuiltinOperator opcode = _model->operator_codes[op->opcode_index]->builtin_code;
148 if (supportedOperators.find(opcode) == supportedOperators.end())
149 {
150 if (opcode <= tflite::BuiltinOperator_MAX)
151 errors.insert(std::string(EnumNameBuiltinOperator(opcode)) + ": unsupported operator");
152 else
153 errors.insert(std::to_string(opcode) + ": unsuppored in tflite custom opcode");
154 }
155 }
156
157 if (!errors.empty())
158 {
159 std::string msg("NNC can't load model. Detected problems:");
160 for (const auto &e : errors)
161 msg.append("\n * " + e);
162 throw std::runtime_error(msg);
163 }
164}
165
// Runs the full pipeline: load + verify the file, reject unsupported operators,
// then translate the model into the MIR graph and hand the graph to the caller.
std::unique_ptr<mir::Graph> TfliteImporter::importModel()
{
  import();
  collectUnsupportedOps();
  walkModel(_model.get());
  // _graph is a data member, so an explicit move is required to transfer ownership out.
  return std::move(_graph);
}
173
174void TfliteImporter::walkModel(const tflite::ModelT *model)
175{
176 for (const auto &subgraph : model->subgraphs)
177 walkSubgraph(subgraph.get());
178}
179
180mir::DataType convertElementType(tflite::TensorType type)
181{
182 switch (type)
183 {
184 case tflite::TensorType_INT32:
185 return mir::DataType::INT32;
186 case tflite::TensorType_FLOAT32:
187 return mir::DataType::FLOAT32;
188 case tflite::TensorType_INT64:
189 return mir::DataType::INT64;
190 case tflite::TensorType_UINT8:
191 return mir::DataType::UINT8;
192 default:
193 throw std::runtime_error(std::string("Unsupported tensor type: ") + EnumNameTensorType(type));
194 }
195}
196
197mir::TensorType getMirTensorType(const tflite::TensorT &tensor)
198{
199 mir::DataType element_type = convertElementType(tensor.type);
200
201 mir::Shape shape(tensor.shape.size());
202 for (std::size_t i = 0; i < tensor.shape.size(); ++i)
203 {
204 shape.dim(i) = tensor.shape[i];
205 }
206
207 if (tensor.quantization != nullptr)
208 {
209 const tflite::QuantizationParametersT &params = *tensor.quantization;
210
211 if (params.details.type != tflite::QuantizationDetails_NONE)
212 throw std::runtime_error("Custom quantization is not supported.");
213
214 // Empty parameters mean no quantization at all.
215 if (params.scale.empty() && params.zero_point.empty())
216 return mir::TensorType{element_type, shape};
217
218 if (params.scale.size() != 1 || params.zero_point.size() != 1)
219 throw std::runtime_error("Non-scalar quantization is not supported.");
220
221 mir::AffineQuantization quantization{params.scale[0], static_cast<int>(params.zero_point[0])};
222
223 return mir::TensorType{element_type, shape, quantization};
224 }
225 else
226 {
227 return mir::TensorType{element_type, shape};
228 }
229}
230
231void TfliteImporter::walkSubgraph(const tflite::SubGraphT *subgraph)
232{
233 _tensorMap.assign(subgraph->tensors.size(), nullptr);
234
235 for (const auto input_tensor_index : subgraph->inputs)
236 {
237 const tflite::TensorT &tensor = *subgraph->tensors[input_tensor_index];
238
239 mir::TensorType input_type = getMirTensorType(tensor);
240 auto input = _graph->create<mir::ops::InputOp>(input_type)->getOutput(0);
241 input->setName(tensor.name);
242
243 assert(_tensorMap[input_tensor_index] == nullptr);
244 _tensorMap[input_tensor_index] = input;
245 }
246
247 for (const auto &op : subgraph->operators)
248 {
249 walkOperator(subgraph, op.get());
250 }
251
252 for (const auto output_tensor_index : subgraph->outputs)
253 {
254 auto output = _tensorMap[output_tensor_index];
255 _graph->create<mir::ops::OutputOp>(output);
256 }
257}
258
// Translates a single TFLite operator into MIR operations.
//
// Resolves the operator's MIR inputs, dispatches on the builtin opcode to the
// matching TFLiteOpCreator conversion, then names/types the produced outputs and
// records them in _tensorMap for downstream operators. Unsupported opcodes cannot
// reach the default branch because collectUnsupportedOps() ran earlier.
void TfliteImporter::walkOperator(const tflite::SubGraphT *subgraph, const tflite::OperatorT *op)
{
  std::vector<mir::Operation::Output *> inputs = getMIRInputsForOperator(subgraph, op);
  std::vector<mir::Operation::Output *> outputs;

  // The operator stores an index into the model's operator-code table, not the opcode itself.
  tflite::BuiltinOperator opcode = _model->operator_codes[op->opcode_index]->builtin_code;
  switch (opcode)
  {
    case tflite::BuiltinOperator_CONV_2D:
      outputs = _opCreator->convertConv2D(op->builtin_options.AsConv2DOptions(), inputs);
      break;
    case tflite::BuiltinOperator_DEPTHWISE_CONV_2D:
      outputs =
        _opCreator->convertDepthwiseConv2D(op->builtin_options.AsDepthwiseConv2DOptions(), inputs);
      break;
    case tflite::BuiltinOperator_MAX_POOL_2D:
      outputs = _opCreator->convertMaxPool2D(op->builtin_options.AsPool2DOptions(), inputs);
      break;
    case tflite::BuiltinOperator_AVERAGE_POOL_2D:
      outputs = _opCreator->convertAveragePool2D(op->builtin_options.AsPool2DOptions(), inputs);
      break;
    case tflite::BuiltinOperator_CONCATENATION:
      outputs =
        _opCreator->convertConcatenation(op->builtin_options.AsConcatenationOptions(), inputs);
      break;
    case tflite::BuiltinOperator_RESHAPE:
      outputs = _opCreator->convertReshape(op->builtin_options.AsReshapeOptions(), inputs);
      break;
    case tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
      outputs = _opCreator->convertResizeNearestNeighbor(
        op->builtin_options.AsResizeNearestNeighborOptions(), inputs);
      break;
    case tflite::BuiltinOperator_MEAN:
      outputs = _opCreator->convertMean(op->builtin_options.AsReducerOptions(), inputs);
      break;
    case tflite::BuiltinOperator_FULLY_CONNECTED:
      outputs =
        _opCreator->convertFullyConnected(op->builtin_options.AsFullyConnectedOptions(), inputs);
      break;
    case tflite::BuiltinOperator_SOFTMAX:
      outputs = _opCreator->convertSoftmax(op->builtin_options.AsSoftmaxOptions(), inputs);
      break;
    case tflite::BuiltinOperator_SLICE:
      outputs = _opCreator->convertSlice(op->builtin_options.AsSliceOptions(), inputs);
      break;
    case tflite::BuiltinOperator_SQUEEZE:
      outputs = _opCreator->convertSqueeze(op->builtin_options.AsSqueezeOptions(), inputs);
      break;
    // The following operators carry no options relevant to conversion; only inputs are passed.
    case tflite::BuiltinOperator_LOGISTIC:
      outputs = _opCreator->convertLogistic(inputs);
      break;
    case tflite::BuiltinOperator_RSQRT:
      outputs = _opCreator->convertRsqrt(inputs);
      break;
    case tflite::BuiltinOperator_SQRT:
      outputs = _opCreator->convertSqrt(inputs);
      break;
    case tflite::BuiltinOperator_ADD:
      outputs = _opCreator->convertAdd(op->builtin_options.AsAddOptions(), inputs);
      break;
    case tflite::BuiltinOperator_SUB:
      outputs = _opCreator->convertSub(op->builtin_options.AsSubOptions(), inputs);
      break;
    case tflite::BuiltinOperator_MUL:
      outputs = _opCreator->convertMul(op->builtin_options.AsMulOptions(), inputs);
      break;
    case tflite::BuiltinOperator_DIV:
      outputs = _opCreator->convertDiv(op->builtin_options.AsDivOptions(), inputs);
      break;
    case tflite::BuiltinOperator_MAXIMUM:
      outputs = _opCreator->convertMax(inputs);
      break;
    case tflite::BuiltinOperator_SQUARED_DIFFERENCE:
      outputs = _opCreator->convertSquaredDifference(inputs);
      break;
    case tflite::BuiltinOperator_TRANSPOSE_CONV:
      outputs =
        _opCreator->convertTransposeConv(op->builtin_options.AsTransposeConvOptions(), inputs);
      break;
    case tflite::BuiltinOperator_PAD:
      outputs = _opCreator->convertPad(op->builtin_options.AsPadOptions(), inputs);
      break;
    case tflite::BuiltinOperator_TANH:
      outputs = _opCreator->convertTanh(inputs);
      break;
    case tflite::BuiltinOperator_RELU:
      outputs = _opCreator->convertReLU(inputs);
      break;
    case tflite::BuiltinOperator_RELU6:
      outputs = _opCreator->convertReLU6(inputs);
      break;
    case tflite::BuiltinOperator_TRANSPOSE:
      outputs = _opCreator->convertTranspose(op->builtin_options.AsTransposeOptions(), inputs);
      break;
    case tflite::BuiltinOperator_STRIDED_SLICE:
      outputs =
        _opCreator->convertStridedSlice(op->builtin_options.AsStridedSliceOptions(), inputs);
      break;
    case tflite::BuiltinOperator_LEAKY_RELU:
      outputs = _opCreator->convertLeakyReLU(op->builtin_options.AsLeakyReluOptions(), inputs);
      break;
    case tflite::BuiltinOperator_SHAPE:
      outputs = _opCreator->convertShape(op->builtin_options.AsShapeOptions(), inputs);
      break;
    case tflite::BuiltinOperator_HARD_SWISH:
      outputs = _opCreator->convertHardSwish(op->builtin_options.AsHardSwishOptions(), inputs);
      break;
    default:
      assert(false && "All unsupported types should have been found before this pass.");
  }

  // Record each produced output under its TFLite tensor index.
  assert(outputs.size() == op->outputs.size());
  for (std::size_t i = 0; i < op->outputs.size(); ++i)
  {
    const auto tensor_index = op->outputs[i];
    const tflite::TensorT &tensor = *subgraph->tensors[tensor_index];

    mir::TensorType output_type = getMirTensorType(tensor);

    // The type should have been inferred correctly, except for quantization information.
    assert(outputs[i]->getType().getElementType() == output_type.getElementType() &&
           outputs[i]->getType().getShape() == output_type.getShape());

    outputs[i]->setName(tensor.name);
    outputs[i]->setType(output_type);

    assert(_tensorMap[tensor_index] == nullptr);
    _tensorMap[tensor_index] = outputs[i];
  }
}
389
390std::vector<mir::Operation::Output *>
391TfliteImporter::getMIRInputsForOperator(const tflite::SubGraphT *subgraph,
392 const tflite::OperatorT *op)
393{
394 std::vector<mir::Operation::Output *> inputs;
395
396 for (const auto tensor_index : op->inputs)
397 {
398 const tflite::TensorT &tensor = *subgraph->tensors[tensor_index];
399 const tflite::BufferT &buffer = *_model->buffers[tensor.buffer];
400 if (!buffer.data.empty())
401 {
402 assert(_tensorMap[tensor_index] == nullptr);
403 mir::TensorType type = getMirTensorType(tensor);
404 mir::TensorVariant mir_tensor{type, buffer.data.data()};
405 inputs.emplace_back(_graph->create<mir::ops::ConstantOp>(mir_tensor)->getOutput(0));
406 }
407 else
408 {
409 assert(_tensorMap[tensor_index] != nullptr);
410 // By this point every input for the operation "op" should have corresponding
411 // Model IR operations that output its inputs. This assumption is provided by the fact
412 // that TFLite format specifies all operations in the execution order.
413 inputs.emplace_back(_tensorMap[tensor_index]);
414 }
415 }
416
417 return inputs;
418}
419
420} // namespace
421
422std::unique_ptr<mir::Graph> loadModel(std::string filename)
423{
424 TfliteImporter importer(std::move(filename));
425 return importer.importModel();
426}
427
428} // namespace mir_tflite
Output * getOutput(std::size_t index)
Definition Operation.h:149
KnobTrait< K >::ValueType get(void)
type
Definition infer.py:18
constexpr DataType getElementType()
Definition TestUtils.h:134
DataType
Definition DataType.h:27
NNFW_TYPE getType(const char *type="")