ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Session.cpp
Go to the documentation of this file.
/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "locomotiv/Session.h"
#include "locomotiv/NodeData.h"

#include "UserData.h"
#include "NodeDataImpl.h"
#include "NodeExecution.h"
#include "NodeDomain.h"

#include <cassert>
#include <stdexcept>

27namespace locomotiv
28{
29
31{
32 for (uint32_t i = 0; i < _graph->nodes()->size(); ++i)
33 {
34 auto node = _graph->nodes()->at(i);
35 erase_user_data(node);
36 erase_annot_data(node);
38 }
39}
40
41void Session::set_input(uint32_t index, std::unique_ptr<NodeData> &&data)
42{
43 assert(index < input_size());
44
45 // Check whether already annotated
46 auto pull = loco::pull_node(_graph, index);
47 if (user_data(pull))
48 {
49 throw std::runtime_error("Graph input already has NodeData");
50 }
51
52 // Check data type match
53 if (pull->dtype() != data->dtype())
54 {
55 throw std::runtime_error("Data type mismatch");
56 }
57
58 // Check shape match
59 auto shape = data->shape();
60 if (pull->rank() != shape->rank())
61 {
62 throw std::runtime_error("Shape rank mismatch");
63 }
64 for (uint32_t i = 0; i < pull->rank(); ++i)
65 {
66 if (pull->dim(i).known() && pull->dim(i).value() != shape->dim(i))
67 {
68 throw std::runtime_error("Shape dimension mismatch");
69 }
70 }
71
72 user_data(pull, std::move(data));
73}
74
76{
77 auto schedules = loco::postorder_traversal(_outputs);
78
79 for (auto node : schedules)
80 {
81 NodeExecution::get().run(node);
82 }
83}
84
85const NodeData *Session::get_output(uint32_t index)
86{
87 assert(index < output_size());
88
89 auto output_node = _outputs.at(index);
90 return annot_data(output_node);
91}
92
93} // namespace locomotiv
NodeContext * nodes(void)
Definition Graph.h:218
T * at(uint32_t n) const
Access N-th object.
Definition ObjectPool.h:41
static NodeExecution & get()
void run(loco::Node *node)
Run calculation for one unspecified Node.
const NodeData * get_output(uint32_t index)
Get output of graph as NodeData.
Definition Session.cpp:85
~Session()
Free all node annotations of the graph assigned by this Session.
Definition Session.cpp:30
void set_input(uint32_t index, std::unique_ptr< NodeData > &&data)
Set graph input at specific index by NodeData.
Definition Session.cpp:41
void infer()
Do inference for this session and graph.
Definition Session.cpp:75
uint32_t output_size() const
Get number of graph outputs held by this Session.
Definition Session.h:90
uint32_t input_size() const
Get number of graph inputs held by this Session.
Definition Session.h:70
std::vector< loco::Node * > postorder_traversal(const std::vector< loco::Node * > &roots)
Generate postorder traversal sequence starting from "roots".
Definition Algorithm.cpp:53
Pull * pull_node(Graph *g, const GraphInputIndex &index)
Find a Pull node with a given input index.
Definition Nodes.cpp:162
void erase_annot_domain(loco::Node *node)
Erase already annotated node domain.
int32_t size[5]
Definition Slice.cpp:35
Read-only, non-template wrapper for 'Buffer'. Serves as the interface for the inputs and outputs of 'Session'.
Definition NodeData.h:36