ONE - On-device Neural Engine
Backend.cpp
/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnkit/support/tf/Backend.h" // Backend class declaration (assumed path; not visible in this listing)

#include <oops/UserExn.h> // for oops::UserExn (assumed; not visible in this listing)

#include <angkor/TensorShape.h>

#include <nnkit/Backend.h>

#include <cstring> // memcpy
namespace nnkit
{
namespace support
{
namespace tf
{

using nnkit::support::tftestinfo::ParsedTensor; // assumed; the using-declaration is not visible in this listing
Backend::Backend(const char *pb_path, const char *info_path) : _tf_runner(pb_path)
{
  auto parsed_tensors = nnkit::support::tftestinfo::parse(info_path);
  for (auto &parsed_tensor : parsed_tensors)
  {
    if (parsed_tensor->kind() == ParsedTensor::Kind::Input)
    {
      // user didn't specify the input shape; ask the runner to infer it from the GraphDef
      if (!parsed_tensor->hasShape())
      {
        angkor::TensorShape shape; // declaration not visible in this listing; filled by the runner below
        if (!_tf_runner.getTensorShapeFromGraphDef(parsed_tensor, shape))
          throw oops::UserExn(
            "Info you provided may be wrong or not enough. Please check the info file.");

        parsed_tensor->mutable_shape().resize(shape.rank());
        for (uint32_t r = 0; r < shape.rank(); r++)
        {
          parsed_tensor->mutable_shape().dim(r) = shape.dim(r);
        }
      }
      _inputs.emplace_back(std::move(parsed_tensor));
    }
    else
      _outputs.emplace_back(std::move(parsed_tensor));
  }
}

void Backend::prepare(const std::function<void(nnkit::TensorContext &)> &f)
{
  // allocate a host buffer for each input tensor
  for (const auto &input_tensor : _inputs)
    _data_map.allocate(input_tensor.get());

  TensorContext ctx(_inputs, _data_map);
  f(ctx); // let the caller fill the input values

  _tf_runner.prepareInputs(_inputs, _data_map);
  _tf_runner.prepareOutputs(_outputs);
}

void Backend::run(void)
{
  _tf_runner.run();

  // get the results from the session and copy them into the data map
  auto actual_outputs = _tf_runner.output();

  for (size_t n = 0; n < _outputs.size(); n++)
  {
    auto actual = actual_outputs[n];
    const size_t byte_size = TF_TensorByteSize(actual);
    const uint8_t *tf_data = reinterpret_cast<const uint8_t *>(TF_TensorData(actual));

    // record the actual output shape reported by TensorFlow
    const uint32_t shape_rank = TF_NumDims(actual);
    _outputs[n]->mutable_shape().resize(shape_rank);
    for (uint32_t r = 0; r < shape_rank; r++)
    {
      _outputs[n]->mutable_shape().dim(r) = TF_Dim(actual, r);
    }

    uint8_t *dest = _data_map.allocate(_outputs[n].get());
    std::memcpy(dest, tf_data, byte_size);
  }
}

void Backend::teardown(const std::function<void(nnkit::TensorContext &)> &f)
{
  // expose the output buffers to the caller (e.g. to dump or compare results)
  TensorContext ctx(_outputs, _data_map);
  f(ctx);
}

} // namespace tf
} // namespace support
} // namespace nnkit
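
A minimal usage sketch follows, showing how a driver might exercise this backend through the prepare/run/teardown lifecycle defined above. Only the constructor signature and the three overrides are taken from the listing; the file paths, the main function, the <nnkit/TensorContext.h> include, and the fill/compare steps are illustrative assumptions.

#include "nnkit/support/tf/Backend.h" // assumed path, as above
#include <nnkit/TensorContext.h>      // assumed header for nnkit::TensorContext

int main(void)
{
  // Illustrative paths; any frozen GraphDef plus a matching test.info would do.
  nnkit::support::tf::Backend backend("model.pb", "test.info");

  // prepare(): allocates input buffers and lets the callback fill them.
  backend.prepare([](nnkit::TensorContext &ctx) {
    // e.g. fill every input tensor exposed through ctx with test values
  });

  // run(): executes the TensorFlow session and captures the outputs.
  backend.run();

  // teardown(): exposes the output buffers for inspection.
  backend.teardown([](nnkit::TensorContext &ctx) {
    // e.g. read every output tensor exposed through ctx and dump or compare it
  });

  return 0;
}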