ONE - On-device Neural Engine
Loading...
Searching...
No Matches
ModelAnalyzer.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "ModelAnalyzer.h"
18#include "mir/Graph.h"
19#include "mir/ops/InputOp.h"
20#include "mir/ops/ReluOp.h"
21#include "mir/ops/ConcatOp.h"
22
23#include <gtest/gtest.h>
24
25#include <algorithm>
26
27using namespace std;
28using namespace nnc;
29using namespace mir;
30using namespace sir;
31
32static const CallFunction *getCall(const unique_ptr<Action> &ptr)
33{
34 return dynamic_cast<const CallFunction *>(ptr.get());
35}
36
37/*
38 * This test designed to check basic layout properties of Model analyzer
39 */
40TEST(ModelAnalyzer, linearization)
41{
42 mir::Graph g;
43 /*
44 * Create graph:
45 * [input]
46 * / \
47 * | |
48 * V V
49 * [head1] [head2]
50 * | |
51 * V V
52 * [tail1] [tail2]
53 * \ /
54 * \ /
55 * [join]
56 */
57 mir::TensorType input_type{mir::DataType::FLOAT32, Shape{1, 2, 3}};
58 Operation *input = g.create<ops::InputOp>(input_type);
59 Operation *head1 = g.create<ops::ReluOp>(input->getOutput(0));
60 Operation *head2 = g.create<ops::ReluOp>(input->getOutput(0));
61 Operation *tail1 = g.create<ops::ReluOp>(head1->getOutput(0));
62 Operation *tail2 = g.create<ops::ReluOp>(head2->getOutput(0));
63 vector<mir::Operation::Output *> concat_inputs{tail1->getOutput(0), tail2->getOutput(0)};
64 Operation *join = g.create<ops::ConcatOp>(concat_inputs, 0);
65 input->getOutput(0)->setName("input");
66 head1->getOutput(0)->setName("head1");
67 head2->getOutput(0)->setName("head2");
68 tail1->getOutput(0)->setName("tail2");
69 tail2->getOutput(0)->setName("tail2");
70 join->getOutput(0)->setName("join");
71
72 // Check that layout is desired
74 ma.analyze(&g);
75 const auto &seq = ma.getInferenceSequence();
76 ASSERT_EQ(seq.size(), 6u);
77
78 vector<Operation *> op_seq(seq.size());
79 transform(seq.cbegin(), seq.cend(), op_seq.begin(),
80 [](const unique_ptr<sir::Action> &action) { return getCall(action)->mirOp; });
81
82 vector<Operation *> valid_seq1{input, head1, tail1, head2, tail2, join};
83 vector<Operation *> valid_seq2{input, head2, tail2, head1, tail1, join};
84 ASSERT_TRUE(op_seq == valid_seq1 || op_seq == valid_seq2);
85}
Output * getOutput(std::size_t index)
Definition Operation.h:149
Description of tensor concatenation operation.
Definition ConcatOp.h:31
Constructs inference sequence for given computational graph, gathers list of variables used in artifa...
const std::vector< std::unique_ptr< sir::Action > > & getInferenceSequence() const
void analyze(const mir::Graph *g)
constructs inference sequence
Definition Shape.h:28
TEST(ModelAnalyzer, linearization)