ONE - On-device Neural Engine
nnkit::support::tf::Runner Class Reference (final)

#include <Runner.h>

Public Types

enum class DataType {
  Unknown, U8, U16, U32,
  U64, S8, S16, S32,
  S64, FLOAT
}
 

Public Member Functions

 Runner (const char *pb_path)
 
 ~Runner ()
 
bool getTensorShapeFromGraphDef (const std::unique_ptr< ParsedTensor > &tensor, angkor::TensorShape &shape)
 Get tensor shape from GraphDef; intended for input tensors only.
 
bool getTensorDtypeFromGraphDef (const std::unique_ptr< ParsedTensor > &tensor, Runner::DataType &dtype)
 Get tensor data type from GraphDef.
 
void prepareInputs (const std::vector< std::unique_ptr< ParsedTensor > > &inputs, TensorDataMap &data_map)
 
void prepareOutputs (const std::vector< std::unique_ptr< ParsedTensor > > &outputs)
 
void run ()
 
const std::vector< TF_Tensor * > & output ()
 

Detailed Description

Definition at line 37 of file Runner.h.

Member Enumeration Documentation

◆ DataType

Enumerator
Unknown  Unknown type (serves as a default value)
U8       8-bit unsigned integer
U16      16-bit unsigned integer
U32      32-bit unsigned integer
U64      64-bit unsigned integer
S8       8-bit signed integer
S16      16-bit signed integer
S32      32-bit signed integer
S64      64-bit signed integer
FLOAT    floating-point

Definition at line 40 of file Runner.h.

{
  Unknown, // Unknown type (serves as a default value)

  U8,  // 8-bit unsigned integer
  U16, // 16-bit unsigned integer
  U32, // 32-bit unsigned integer
  U64, // 64-bit unsigned integer

  S8,  // 8-bit signed integer
  S16, // 16-bit signed integer
  S32, // 32-bit signed integer
  S64, // 64-bit signed integer

  FLOAT, // floating-point
};

Constructor & Destructor Documentation

◆ Runner()

nnkit::support::tf::Runner::Runner ( const char *  pb_path)

Definition at line 125 of file Runner.cpp.

{
  // initialize member variables
  _sess = nullptr;
  _graph = TF_NewGraph();
  _status = TF_NewStatus();

  // import graph from file
  TF_Buffer *buffer = build_TFBuffer(pb_path);
  if (buffer == nullptr)
    throw std::runtime_error("Can't read buffer from file");

  TF_ImportGraphDefOptions *opts = TF_NewImportGraphDefOptions();

  TF_GraphImportGraphDef(_graph, buffer, opts, _status);

  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(buffer);

  if (TF_GetCode(_status) != TF_OK) // TODO Consider wrapper to prevent memory leak
    throw std::runtime_error("Can't import GraphDef");
}
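For illustration, a minimal caller-side sketch (the model path is hypothetical); the constructor throws std::runtime_error if the file cannot be read or the GraphDef cannot be imported:

#include <iostream>
#include <stdexcept>

#include <Runner.h>

int main()
{
  try
  {
    // Load a frozen GraphDef; "frozen_model.pb" is a placeholder path.
    nnkit::support::tf::Runner runner("frozen_model.pb");
    // ... prepare inputs/outputs and call run() here ...
  }
  catch (const std::runtime_error &e)
  {
    std::cerr << e.what() << std::endl;
    return 1;
  }
  return 0;
}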

◆ ~Runner()

nnkit::support::tf::Runner::~Runner ( )

Definition at line 148 of file Runner.cpp.

{
  if (_graph)
    TF_DeleteGraph(_graph);

  if (_sess)
  {
    TF_CloseSession(_sess, _status);
    TF_DeleteSession(_sess, _status);
  }

  for (auto tensor : _input_tensors)
    TF_DeleteTensor(tensor);

  for (auto tensor : _output_tensors)
    TF_DeleteTensor(tensor);

  TF_DeleteStatus(_status);
}

Member Function Documentation

◆ getTensorDtypeFromGraphDef()

bool nnkit::support::tf::Runner::getTensorDtypeFromGraphDef (const std::unique_ptr< ParsedTensor > & tensor,
                                                             Runner::DataType & dtype)

Get tensor data type from GraphDef.

Note
If the node cannot be found or dtype of the node is unknown, it returns false.

Definition at line 195 of file Runner.cpp.

{
  TF_Output tensor_op = {TF_GraphOperationByName(_graph, tensor->nodeName().c_str()),
                         tensor->tensorIndex()};

  if (tensor_op.oper == nullptr)
    return false;

  TF_DataType tf_dtype = TF_OperationOutputType(tensor_op);

  switch (tf_dtype)
  {
    case TF_DataType::TF_FLOAT:
      dtype = DataType::FLOAT;
      break;
    case TF_DataType::TF_UINT8:
      dtype = DataType::U8;
      break;
    case TF_DataType::TF_UINT16:
      dtype = DataType::U16;
      break;
    case TF_DataType::TF_UINT32:
      dtype = DataType::U32;
      break;
    case TF_DataType::TF_UINT64:
      dtype = DataType::U64;
      break;
    case TF_DataType::TF_INT8:
      dtype = DataType::S8;
      break;
    case TF_DataType::TF_INT16:
      dtype = DataType::S16;
      break;
    case TF_DataType::TF_INT32:
      dtype = DataType::S32;
      break;
    case TF_DataType::TF_INT64:
      dtype = DataType::S64;
      break;
    default:
      dtype = DataType::Unknown;
      return false;
  }
  return true;
}
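A hedged usage sketch; runner and tensor (a std::unique_ptr<ParsedTensor> obtained elsewhere, e.g. from the nnkit test-case parser) are assumed to be in scope:

using Runner = nnkit::support::tf::Runner;

// Sketch only: query the dtype of a node declared in the imported GraphDef.
Runner::DataType dtype = Runner::DataType::Unknown;
if (!runner.getTensorDtypeFromGraphDef(tensor, dtype))
  throw std::runtime_error("node not found or dtype unknown: " + tensor->name());

if (dtype != Runner::DataType::FLOAT)
  throw std::runtime_error("only float tensors are handled in this sketch");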

◆ getTensorShapeFromGraphDef()

bool nnkit::support::tf::Runner::getTensorShapeFromGraphDef (const std::unique_ptr< ParsedTensor > & tensor,
                                                             angkor::TensorShape & shape)

Get tensor shape from GraphDef; intended for input tensors only.

Note
If the node cannot be found, or the shape in the GraphDef is unknown or only partially known (in which case the caller must supply the shape), it returns false.

Definition at line 168 of file Runner.cpp.

{
  assert(!tensor->hasShape());
  TF_Output tensor_op = {TF_GraphOperationByName(_graph, tensor->nodeName().c_str()),
                         tensor->tensorIndex()};

  if (tensor_op.oper == nullptr)
    return false;

  int dim_size = TF_GraphGetTensorNumDims(_graph, tensor_op, _status);
  if (dim_size == -1)
    return false;
  int64_t dims[dim_size];

  TF_GraphGetTensorShape(_graph, tensor_op, dims, dim_size, _status);

  shape.resize(dim_size);
  for (int d = 0; d < dim_size; d++)
  {
    if (dims[d] == -1)
      return false;
    shape.dim(d) = dims[d];
  }
  return true;
}

References nncc::core::ADT::tensor::Shape::dim(), and nncc::core::ADT::tensor::Shape::resize().

Referenced by nnkit::support::tf::Backend::Backend().
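A hedged usage sketch; the implementation asserts that the ParsedTensor does not already carry a shape, so a caller is expected to consult the GraphDef only for tensors whose shape was not given:

// Sketch only: fall back to the GraphDef when no shape was provided.
angkor::TensorShape shape;
if (!tensor->hasShape())
{
  if (!runner.getTensorShapeFromGraphDef(tensor, shape))
    throw std::runtime_error("shape of '" + tensor->name() + "' is unknown; it must be provided");
}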

◆ output()

const std::vector< TF_Tensor * > & nnkit::support::tf::Runner::output ( )
inline

Definition at line 86 of file Runner.h.

{ return _output_tensors; }

Referenced by nnkit::support::tf::Backend::run().

◆ prepareInputs()

void nnkit::support::tf::Runner::prepareInputs (const std::vector< std::unique_ptr< ParsedTensor > > & inputs,
                                                TensorDataMap & data_map)

Definition at line 242 of file Runner.cpp.

{
  assert(_graph);

  for (const auto &tensor : inputs)
  {
    TF_Output input_op = {TF_GraphOperationByName(_graph, tensor->nodeName().c_str()),
                          tensor->tensorIndex()};

    if (input_op.oper == nullptr)
      throw std::runtime_error("Can't init input_op : " + tensor->name());

    std::vector<int64_t> shape;
    for (int r = 0; r < tensor->shape().rank(); r++)
      shape.emplace_back(tensor->shape().dim(r));

    int size = 0;
    if (tensor->isFloatTensor())
      size = sizeof(float);
    else
      throw std::runtime_error("Not supported tensor type");

    TF_Tensor *input_tensor =
      create_tensor(TF_FLOAT, shape.data(), shape.size(), data_map.data(tensor.get()),
                    num_elements(tensor->shape()) * size);

    _input_ops.emplace_back(input_op);
    _input_tensors.emplace_back(input_tensor);
  }
}

References nnkit::support::tf::TensorDataMap::data(), size, and TF_FLOAT.

Referenced by nnkit::support::tf::Backend::prepare().
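Since prepareInputs() currently supports float tensors only (anything else throws), a hedged caller-side sketch might validate first; runner, inputs, and data_map are assumed to be set up elsewhere:

// Sketch only: reject non-float inputs before building TF_Tensor objects.
for (const auto &tensor : inputs)
{
  if (!tensor->isFloatTensor())
    throw std::runtime_error("only float input tensors are supported: " + tensor->name());
}
runner.prepareInputs(inputs, data_map);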

◆ prepareOutputs()

void nnkit::support::tf::Runner::prepareOutputs ( const std::vector< std::unique_ptr< ParsedTensor > > &  outputs)

Definition at line 274 of file Runner.cpp.

{
  assert(_graph);

  for (const auto &tensor : outputs)
  {
    TF_Output output_op = {TF_GraphOperationByName(_graph, tensor->nodeName().c_str()),
                           tensor->tensorIndex()};

    if (output_op.oper == nullptr)
      throw std::runtime_error("Can't init output_op : " + tensor->name());

    _output_ops.emplace_back(output_op);
  }

  _output_tensors.resize(_output_ops.size());
}

Referenced by nnkit::support::tf::Backend::prepare().

◆ run()

void nnkit::support::tf::Runner::run ( )

Definition at line 292 of file Runner.cpp.

{
  assert(_graph);
  assert(_output_ops.size() > 0);

  TF_SessionOptions *options = TF_NewSessionOptions();
  _sess = TF_NewSession(_graph, options, _status);
  TF_DeleteSessionOptions(options);

  if (TF_GetCode(_status) != TF_OK)
    throw std::runtime_error(TF_Message(_status));

  TF_SessionRun(_sess,
                nullptr, // Run options.
                _input_ops.data(), _input_tensors.data(), _input_ops.size(), _output_ops.data(),
                _output_tensors.data(), _output_ops.size(), nullptr,
                0,       // Target operations, number of targets.
                nullptr, // Run metadata.
                _status  // Output status.
  );

  if (TF_GetCode(_status) != TF_OK)
    throw std::runtime_error(TF_Message(_status));

  TF_CloseSession(_sess, _status);
  TF_DeleteSession(_sess, _status);
  _sess = nullptr;
}

Referenced by package.infer.session::inference(), and nnkit::support::tf::Backend::run().
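Putting the pieces together, a hedged sketch of the overall call sequence a backend might follow; inputs, outputs (std::vector<std::unique_ptr<ParsedTensor>>) and data_map (TensorDataMap) are assumed to be filled in elsewhere, and the model path is a placeholder:

// Sketch of the overall flow only.
nnkit::support::tf::Runner runner("frozen_model.pb");

runner.prepareInputs(inputs, data_map); // build TF_Tensor inputs from data_map
runner.prepareOutputs(outputs);         // resolve output operations by name

runner.run();                           // create a session, run it, close it

const std::vector<TF_Tensor *> &results = runner.output();
// read results here; the Runner deletes these tensors in its destructor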


The documentation for this class was generated from the following files:

Runner.h
Runner.cpp