ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Session.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef _LOCOMOTIV_SESSION_H_
18#define _LOCOMOTIV_SESSION_H_
19
20#include "locomotiv/NodeData.h"
21
22#include <loco.h>
23
24#include <memory>
25#include <vector>
26
27namespace locomotiv
28{
29
33class Session final
34{
35public:
36 Session() = delete;
37
39 Session(loco::Graph *g) : Session(g, loco::output_nodes(g))
40 {
41 // DO NOTHING
42 }
43
53 Session(loco::Graph *g, const std::vector<loco::Node *> &custom_outputs)
54 : _graph(g), _outputs(custom_outputs)
55 {
56 // DO NOTHING
57 }
58
60 template <typename InputIt>
61 Session(loco::Graph *g, InputIt begin, InputIt end) : _graph(g), _outputs(begin, end)
62 {
63 // DO NOTHING
64 }
65
67 ~Session();
68
70 uint32_t input_size() const { return _graph->inputs()->size(); }
71
79 void set_input(uint32_t index, std::unique_ptr<NodeData> &&data);
80
87 void infer();
88
90 uint32_t output_size() const { return _outputs.size(); }
91
97 const NodeData *get_output(uint32_t index);
98
99 const loco::Node *get_output_node(uint32_t index) { return _outputs.at(index); }
100
101private:
102 loco::Graph *_graph;
103 std::vector<loco::Node *> _outputs;
104};
105
106} // namespace locomotiv
107
108#endif // _LOCOMOTIV_SESSION_H_
A neural network graph.
Definition Graph.h:161
InputContext * inputs(void)
Definition Graph.h:220
Logical unit of computation.
Definition Node.h:54
uint32_t size(void) const
Return the number of objects.
Definition ObjectPool.h:38
Session for loco graph inference.
Definition Session.h:34
const NodeData * get_output(uint32_t index)
Get output of graph as NodeData.
Definition Session.cpp:85
~Session()
Free all node annotations of the graph assigned by this Session.
Definition Session.cpp:30
void set_input(uint32_t index, std::unique_ptr< NodeData > &&data)
Set graph input at specific index by NodeData.
Definition Session.cpp:41
Session(loco::Graph *g, const std::vector< loco::Node * > &custom_outputs)
Make Session for graph with selective custom outputs. Only subgraph to calculate given outputs would ...
Definition Session.h:53
void infer()
Do inference for this session and graph.
Definition Session.cpp:75
const loco::Node * get_output_node(uint32_t index)
Definition Session.h:99
uint32_t output_size() const
Get number of graph outputs held by this Session.
Definition Session.h:90
Session(loco::Graph *g, InputIt begin, InputIt end)
Make Session by range.
Definition Session.h:61
uint32_t input_size() const
Get number of graph inputs held by this Session.
Definition Session.h:70
Session(loco::Graph *g)
Make a Session for the graph, using the graph's own outputs.
Definition Session.h:39
int32_t begin[5]
Definition Slice.cpp:33
Read-only, non-template wrapper for 'Buffer'. Serves as the interface for the input and output of 'Session'.
Definition NodeData.h:36