ONE - On-device Neural Engine
Runner.h
/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __NNKIT_SUPPORT_TF_RUNNER_H__
#define __NNKIT_SUPPORT_TF_RUNNER_H__

// ParsedTensor stores tensor information parsed from a test.info file (moco/test/tf);
// TensorDataMap maps a parsed tensor to the memory holding its values.
#include "nnkit/support/tftestinfo/ParsedTensor.h"
#include "nnkit/support/tf/TensorDataMap.h"

#include <angkor/TensorShape.h>

#include <tensorflow/c/c_api.h>

#include <memory>
#include <vector>

namespace nnkit
{
namespace support
{
namespace tf
{

using nnkit::support::tftestinfo::ParsedTensor;

class Runner final
{
public:
  enum class DataType
  {
    Unknown, // Unknown type (serves as a default value)

    U8,  // 8-bit unsigned integer
    U16, // 16-bit unsigned integer
    U32, // 32-bit unsigned integer
    U64, // 64-bit unsigned integer

    S8,  // 8-bit signed integer
    S16, // 16-bit signed integer
    S32, // 32-bit signed integer
    S64, // 64-bit signed integer

    FLOAT, // floating-point
  };

public:
  Runner(const char *pb_path);

  ~Runner();

  /**
   * @brief Get tensor shape from GraphDef for input tensor only.
   */
  bool getTensorShapeFromGraphDef(const std::unique_ptr<ParsedTensor> &tensor,
                                  angkor::TensorShape &shape);

  /**
   * @brief Get tensor data type from GraphDef.
   */
  bool getTensorDtypeFromGraphDef(const std::unique_ptr<ParsedTensor> &tensor,
                                  Runner::DataType &dtype);

  void prepareInputs(const std::vector<std::unique_ptr<ParsedTensor>> &inputs,
                     TensorDataMap &data_map);

  void prepareOutputs(const std::vector<std::unique_ptr<ParsedTensor>> &outputs);

  void run();

  const std::vector<TF_Tensor *> &output() { return _output_tensors; }

private:
  TF_Graph *_graph;
  TF_Session *_sess;

  std::vector<TF_Output> _input_ops;
  std::vector<TF_Tensor *> _input_tensors;

  std::vector<TF_Output> _output_ops;
  std::vector<TF_Tensor *> _output_tensors;

  TF_Status *_status;
};

} // namespace tf
} // namespace support
} // namespace nnkit

#endif // __NNKIT_SUPPORT_TF_RUNNER_H__
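
The header only declares the API, so a short usage sketch may help show the intended call order: construct a Runner from a frozen GraphDef path, optionally query shapes and dtypes, bind inputs, declare outputs, run, and read the resulting TF_Tensor buffers. This sketch is not taken from the ONE sources: the function name run_frozen_graph is hypothetical, and how the ParsedTensor vectors and the TensorDataMap are produced is assumed to happen elsewhere. Only the Runner member calls come from this header; TF_TensorData() is part of the TensorFlow C API.

// A minimal usage sketch, not part of Runner.h: "run_frozen_graph" is a
// hypothetical driver, and its ParsedTensor/TensorDataMap arguments are
// assumed to come from the caller.
#include "nnkit/support/tf/Runner.h"

#include <memory>
#include <vector>

using nnkit::support::tf::Runner;
using nnkit::support::tf::TensorDataMap;
using nnkit::support::tftestinfo::ParsedTensor;

void run_frozen_graph(const char *pb_path,
                      const std::vector<std::unique_ptr<ParsedTensor>> &inputs,
                      const std::vector<std::unique_ptr<ParsedTensor>> &outputs,
                      TensorDataMap &data_map)
{
  // Construct from the path of a frozen GraphDef (.pb) file
  Runner runner(pb_path);

  // Shape and dtype can be queried from the GraphDef before binding any data;
  // the bool return presumably signals whether the GraphDef provides the information
  for (const auto &input : inputs)
  {
    angkor::TensorShape shape;
    Runner::DataType dtype = Runner::DataType::Unknown;

    runner.getTensorShapeFromGraphDef(input, shape); // input tensors only
    runner.getTensorDtypeFromGraphDef(input, dtype);
  }

  // Bind input buffers, declare the tensors to fetch, then execute the graph
  runner.prepareInputs(inputs, data_map);
  runner.prepareOutputs(outputs);
  runner.run();

  // Outputs are exposed as TensorFlow C API tensors
  for (TF_Tensor *tensor : runner.output())
  {
    void *raw = TF_TensorData(tensor); // raw element buffer of each output
    (void)raw;
  }
}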
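
The Runner::DataType values line up with element types of the TensorFlow C API that the header already pulls in. Below is an illustrative conversion sketch under that assumption; as_tf_dtype is a hypothetical helper name, and the conversion actually used by Runner.cpp is not shown on this page and may differ.

// Illustrative only: a possible Runner::DataType -> TF_DataType mapping.
// "as_tf_dtype" is a hypothetical name, not a function declared by Runner.h.
#include "nnkit/support/tf/Runner.h"

#include <tensorflow/c/c_api.h>

#include <stdexcept>

using nnkit::support::tf::Runner;

inline TF_DataType as_tf_dtype(Runner::DataType dtype)
{
  switch (dtype)
  {
    case Runner::DataType::U8:
      return TF_UINT8;
    case Runner::DataType::U16:
      return TF_UINT16;
    case Runner::DataType::U32:
      return TF_UINT32;
    case Runner::DataType::U64:
      return TF_UINT64;
    case Runner::DataType::S8:
      return TF_INT8;
    case Runner::DataType::S16:
      return TF_INT16;
    case Runner::DataType::S32:
      return TF_INT32;
    case Runner::DataType::S64:
      return TF_INT64;
    case Runner::DataType::FLOAT:
      return TF_FLOAT; // the enum has a single floating-point entry
    default:
      throw std::runtime_error{"unsupported Runner::DataType"};
  }
}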