ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Support.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "Support.hpp"
19
20#include <tensorflow/core/framework/graph.pb.h>
21
22#include <memory>
23#include <cassert>
24#include <fstream>
25#include <stdexcept>
26
27namespace
28{
29
/**
 * @brief Create a file stream of type T for the given path
 *
 * @tparam T    Stream type; it must be constructible from (const char *, openmode)
 *              and provide is_open() (e.g. std::ifstream, std::ofstream)
 * @param path  File to open, or "-" to request no file stream at all
 * @param mode  Open mode forwarded to the stream constructor
 *
 * @return nullptr when path is "-", otherwise an owning pointer to the opened stream
 *
 * @throws std::runtime_error when the stream could not be opened
 */
template <typename T>
std::unique_ptr<T> open_fstream(const std::string &path, std::ios_base::openmode mode)
{
  // "-" is the "no file" marker: the caller receives a null stream
  const bool no_file_requested = (path == "-");
  if (no_file_requested)
  {
    return nullptr;
  }

  auto file = std::make_unique<T>(path.c_str(), mode);
  if (file->is_open())
  {
    return file;
  }

  throw std::runtime_error{"Failed to open " + path};
}
45
46} // namespace
47
48namespace tfkit
49{
50namespace tf
51{
52
53bool HasAttr(const tensorflow::NodeDef &node, const std::string &attr_name)
54{
55 return node.attr().count(attr_name) > 0;
56}
57
58tensorflow::DataType GetDataTypeAttr(const tensorflow::NodeDef &node, const std::string &attr_name)
59{
60 assert(HasAttr(node, attr_name));
61 const auto &attr = node.attr().at(attr_name);
62 assert(attr.value_case() == tensorflow::AttrValue::kType);
63 return attr.type();
64}
65
66tensorflow::TensorProto *GetTensorAttr(tensorflow::NodeDef &node, const std::string &attr_name)
67{
68 assert(HasAttr(node, attr_name));
69 tensorflow::AttrValue &attr = node.mutable_attr()->at(attr_name);
70 assert(attr.value_case() == tensorflow::AttrValue::kTensor);
71 return attr.mutable_tensor();
72}
73
74int GetElementCount(const tensorflow::TensorShapeProto &shape)
75{
76 int count = -1;
77
78 for (auto &d : shape.dim())
79 {
80 if (d.size() == 0)
81 {
82 count = 0;
83 break;
84 }
85 if (count == -1)
86 count = 1;
87
88 count *= d.size();
89 }
90 return count;
91}
92
93} // namespace tf
94
95std::string CmdArguments::get(unsigned int index) const
96{
97 if (index >= _argc)
98 throw std::runtime_error("Argument index out of bound");
99
100 return std::string(_argv[index]);
101}
102
103std::string CmdArguments::get_or(unsigned int index, const std::string &s) const
104{
105 if (index >= _argc)
106 return s;
107
108 return std::string(_argv[index]);
109}
110
111std::unique_ptr<IOConfiguration> make_ioconfig(const CmdArguments &cmdargs)
112{
113 auto iocfg = std::make_unique<IOConfiguration>();
114
115 auto in = open_fstream<std::ifstream>(cmdargs.get_or(0, "-"), std::ios::in | std::ios::binary);
116 iocfg->in(std::move(in));
117
118 auto out = open_fstream<std::ofstream>(cmdargs.get_or(1, "-"), std::ios::out | std::ios::binary);
119 iocfg->out(std::move(out));
120
121 return iocfg;
122}
123
124} // namespace tfkit
std::string get_or(unsigned int index, const std::string &) const
Definition Support.cpp:103
std::string get(unsigned int index) const
Definition Support.cpp:95
bool HasAttr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Support.cpp:53
int GetElementCount(const tensorflow::TensorShapeProto &shape)
GetElementCount returns -1 for rank-0 tensor shape.
Definition Support.cpp:74
tensorflow::TensorProto * GetTensorAttr(tensorflow::NodeDef &node, const std::string &attr_name)
Definition Support.cpp:66
tensorflow::DataType GetDataTypeAttr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Support.cpp:58
std::unique_ptr< IOConfiguration > make_ioconfig(const CmdArguments &cmdargs)
Definition Support.cpp:111