ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Context.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Context.h"
18
19#include "Convert.h"
20
21#include <coco/IR/Data.h>
22#include <coco/IR/Module.h>
23
25#include <schema_generated.h>
26
27#include <map>
28#include <sstream>
29
30using namespace nncc::core::ADT;
31
32namespace tflimport
33{
34
35void TensorContext::prepare(const tflite::SubGraph *graph)
36{
37 for (uint32_t tensor_id = 0; tensor_id < graph->tensors()->size(); ++tensor_id)
38 {
39 auto const tensor_info = graph->tensors()->Get(tensor_id);
40 auto const tensor_name = tensor_info->name()->str();
41 auto const tensor_shape = as_tensor_shape(tensor_info->shape());
42 auto const tensor_type = tensor_info->type();
43
44 _name_ctx[tensor_id] = tensor_name;
45 _shape_ctx[tensor_id] = tensor_shape;
46 _type_ctx[tensor_id] = tensor_type;
47 }
48}
49
52{
53 for (const tflite::OperatorCode *opcode : *opcodes)
54 {
55 _opcodes.push_back(opcode);
56 }
57}
58
59tflite::BuiltinOperator TflOpCodeContext::builtin_code(const tflite::Operator *op) const
60{
61 uint32_t index = op->opcode_index();
62 assert(index < _opcodes.size());
63 const tflite::OperatorCode *opcode = _opcodes.at(index);
64 return opcode->builtin_code();
65}
66
67std::string TflOpCodeContext::opcode_name(const tflite::Operator *op) const
68{
69 uint32_t index = op->opcode_index();
70 assert(index < _opcodes.size());
71 const tflite::OperatorCode *opcode = _opcodes.at(index);
72
73 if (!is_valid(opcode))
74 {
75 std::ostringstream oss;
76 oss << "(invalid: " << index << ")";
77 return oss.str();
78 }
79
80 if (is_custom(opcode))
81 {
82 if (!opcode->custom_code())
83 return "(invalid custom)";
84
85 return opcode->custom_code()->c_str();
86 }
87
88 tflite::BuiltinOperator code = opcode->builtin_code();
89 return EnumNameBuiltinOperator(code);
90}
91
92bool TflOpCodeContext::is_valid(const tflite::OperatorCode *opcode)
93{
94 tflite::BuiltinOperator code = opcode->builtin_code();
95 return (tflite::BuiltinOperator_MIN <= code && code <= tflite::BuiltinOperator_MAX);
96}
97
98bool TflOpCodeContext::is_custom(const tflite::OperatorCode *opcode)
99{
100 tflite::BuiltinOperator code = opcode->builtin_code();
101 return (code == tflite::BuiltinOperator_CUSTOM);
102}
103
104TflBufferContext::TflBufferContext(const tflite::Model *tfl_model)
105{
107
108 tfl_buffers = tfl_model->buffers();
109
110 for (uint32_t buffer_id = 0; buffer_id < tfl_buffers->size(); ++buffer_id)
111 {
112 _buffer_ctx[buffer_id] = (*tfl_buffers)[buffer_id];
113 }
114}
115
116} // namespace tflimport
uoffset_t size() const
void prepare(const tflite::SubGraph *graph)
Definition Context.cpp:35
TflBufferContext(const tflite::Model *tfl_model)
Definition Context.cpp:104
TflOpCodeContext(const flatbuffers::Vector< flatbuffers::Offset< tflite::OperatorCode > > *opcodes)
Definition Context.cpp:50
tflite::BuiltinOperator builtin_code(const tflite::Operator *op) const
Returns BuiltinOperator value of the operator.
Definition Context.cpp:59
std::string opcode_name(const tflite::Operator *op) const
Returns human readable name of the operator code of the operator.
Definition Context.cpp:67
static bool is_custom(const tflite::OperatorCode *opcode)
Definition Context.cpp:98
static bool is_valid(const tflite::OperatorCode *opcode)
Definition Context.cpp:92
tensor::Shape as_tensor_shape(const flatbuffers::Vector< int32_t > *shape)
Converts flatbuffers::Vector to nncc::core::ADT::tensor::Shape.
Definition Convert.cpp:42
int32_t size[5]
Definition Slice.cpp:35