46#include <tensorflow/c/c_api.h>
65TF_Tensor *create_tensor(
const TF_DataType data_type,
const std::int64_t *dims,
66 const std::size_t num_dims,
const void *data,
const std::size_t len)
68 if (dims ==
nullptr || data ==
nullptr)
71 TF_Tensor *
tensor = TF_AllocateTensor(data_type, dims,
static_cast<int>(num_dims), len);
72 if (tensor ==
nullptr)
75 void *tensor_data = TF_TensorData(tensor);
76 if (tensor_data ==
nullptr)
78 TF_DeleteTensor(tensor);
82 std::memcpy(tensor_data, data, std::min(len, TF_TensorByteSize(tensor)));
87void deallocate_buffer(
void *data,
size_t)
93TF_Buffer *build_TFBuffer(
const char *file)
95 const auto f = std::fopen(file,
"rb");
98 throw std::runtime_error(std::string(
"cannot open ") + file);
100 std::fseek(f, 0, SEEK_END);
101 const auto fsize = ftell(f);
103 std::fseek(f, 0, SEEK_SET);
108 throw std::runtime_error(std::string(
"file read error: ") + file);
111 const auto data = std::malloc(fsize);
112 std::fread(data, fsize, 1, f);
115 TF_Buffer *
buf = TF_NewBuffer();
118 buf->data_deallocator = deallocate_buffer;
129 _graph = TF_NewGraph();
130 _status = TF_NewStatus();
133 TF_Buffer *buffer = build_TFBuffer(pb_path);
134 if (buffer ==
nullptr)
135 throw std::runtime_error(
"Can't read buffer from file");
137 TF_ImportGraphDefOptions *opts = TF_NewImportGraphDefOptions();
139 TF_GraphImportGraphDef(_graph, buffer, opts, _status);
141 TF_DeleteImportGraphDefOptions(opts);
142 TF_DeleteBuffer(buffer);
144 if (TF_GetCode(_status) != TF_OK)
145 throw std::runtime_error(
"Can't import GraphDef");
151 TF_DeleteGraph(_graph);
155 TF_CloseSession(_sess, _status);
156 TF_DeleteSession(_sess, _status);
159 for (
auto tensor : _input_tensors)
160 TF_DeleteTensor(tensor);
162 for (
auto tensor : _output_tensors)
163 TF_DeleteTensor(tensor);
165 TF_DeleteStatus(_status);
171 assert(!tensor->hasShape());
172 TF_Output tensor_op = {TF_GraphOperationByName(_graph, tensor->nodeName().c_str()),
173 tensor->tensorIndex()};
175 if (tensor_op.oper ==
nullptr)
178 int dim_size = TF_GraphGetTensorNumDims(_graph, tensor_op, _status);
181 int64_t dims[dim_size];
183 TF_GraphGetTensorShape(_graph, tensor_op, dims, dim_size, _status);
186 for (
int d = 0; d < dim_size; d++)
190 shape.
dim(d) = dims[d];
198 TF_Output tensor_op = {TF_GraphOperationByName(_graph, tensor->nodeName().c_str()),
199 tensor->tensorIndex()};
201 if (tensor_op.oper ==
nullptr)
204 TF_DataType tf_dtype = TF_OperationOutputType(tensor_op);
208 case TF_DataType::TF_FLOAT:
209 dtype = DataType::FLOAT;
211 case TF_DataType::TF_UINT8:
212 dtype = DataType::U8;
214 case TF_DataType::TF_UINT16:
215 dtype = DataType::U16;
217 case TF_DataType::TF_UINT32:
218 dtype = DataType::U32;
220 case TF_DataType::TF_UINT64:
221 dtype = DataType::U64;
223 case TF_DataType::TF_INT8:
224 dtype = DataType::S8;
226 case TF_DataType::TF_INT16:
227 dtype = DataType::S16;
229 case TF_DataType::TF_INT32:
230 dtype = DataType::S32;
232 case TF_DataType::TF_INT64:
233 dtype = DataType::S64;
236 dtype = DataType::Unknown;
247 for (
const auto &tensor : inputs)
249 TF_Output input_op = {TF_GraphOperationByName(_graph, tensor->nodeName().c_str()),
250 tensor->tensorIndex()};
252 if (input_op.oper ==
nullptr)
253 throw std::runtime_error(
"Can't init input_op : " + tensor->name());
255 std::vector<int64_t> shape;
256 for (
int r = 0; r < tensor->shape().rank(); r++)
257 shape.emplace_back(tensor->shape().dim(r));
260 if (tensor->isFloatTensor())
261 size =
sizeof(float);
263 throw std::runtime_error(
"Not supported tensor type");
265 TF_Tensor *input_tensor =
266 create_tensor(
TF_FLOAT, shape.data(), shape.size(), data_map.
data(tensor.get()),
267 num_elements(tensor->shape()) *
size);
269 _input_ops.emplace_back(input_op);
270 _input_tensors.emplace_back(input_tensor);
278 for (
const auto &tensor : outputs)
280 TF_Output output_op = {TF_GraphOperationByName(_graph, tensor->nodeName().c_str()),
281 tensor->tensorIndex()};
283 if (output_op.oper ==
nullptr)
284 throw std::runtime_error(
"Can't init output_op : " + tensor->name());
286 _output_ops.emplace_back(output_op);
289 _output_tensors.resize(_output_ops.size());
295 assert(_output_ops.size() > 0);
297 TF_SessionOptions *options = TF_NewSessionOptions();
298 _sess = TF_NewSession(_graph, options, _status);
299 TF_DeleteSessionOptions(options);
301 if (TF_GetCode(_status) != TF_OK)
302 throw std::runtime_error(TF_Message(_status));
306 _input_ops.data(), _input_tensors.data(), _input_ops.size(), _output_ops.data(),
307 _output_tensors.data(), _output_ops.size(),
nullptr,
313 if (TF_GetCode(_status) != TF_OK)
314 throw std::runtime_error(TF_Message(_status));
316 TF_CloseSession(_sess, _status);
317 TF_DeleteSession(_sess, _status);
uint32_t & dim(uint32_t axis)
Shape & resize(uint32_t size)
void prepareInputs(const std::vector< std::unique_ptr< ParsedTensor > > &inputs, TensorDataMap &data_map)
void prepareOutputs(const std::vector< std::unique_ptr< ParsedTensor > > &outputs)
bool getTensorShapeFromGraphDef(const std::unique_ptr< ParsedTensor > &tensor, angkor::TensorShape &shape)
Get tensor shape from GraphDef for input tensor only.
bool getTensorDtypeFromGraphDef(const std::unique_ptr< ParsedTensor > &tensor, Runner::DataType &dtype)
Get tensor data type from GraphDef.
Runner(const char *pb_path)
Class to map parsed tensor and memory for tensor values. For parsed tensor, this memory is used to fi...
uint8_t * data(const ParsedTensor *parsed_tensor)
Class to store tensor information parsed from test.info file under moco/test/tf.
uint64_t num_elements(const Shape &)
DataType
Supported Data Types.