ONE - On-device Neural Engine
Loading...
Searching...
No Matches
ModelAnalyzer.cpp File Reference
#include "ModelAnalyzer.h"
#include "mir/Graph.h"
#include "mir/ops/InputOp.h"
#include "mir/ops/ReluOp.h"
#include "mir/ops/ConcatOp.h"
#include <gtest/gtest.h>
#include <algorithm>

Go to the source code of this file.

Functions

 TEST (ModelAnalyzer, linearization)
 

Function Documentation

◆ TEST()

TEST ( ModelAnalyzer  ,
linearization   
)

Definition at line 40 of file ModelAnalyzer.cpp.

41{
43 /*
44 * Create graph:
45 * [input]
46 * / \
47 * | |
48 * V V
49 * [head1] [head2]
50 * | |
51 * V V
52 * [tail1] [tail2]
53 * \ /
54 * \ /
55 * [join]
56 */
57 mir::TensorType input_type{mir::DataType::FLOAT32, Shape{1, 2, 3}};
58 Operation *input = g.create<ops::InputOp>(input_type);
59 Operation *head1 = g.create<ops::ReluOp>(input->getOutput(0));
60 Operation *head2 = g.create<ops::ReluOp>(input->getOutput(0));
61 Operation *tail1 = g.create<ops::ReluOp>(head1->getOutput(0));
62 Operation *tail2 = g.create<ops::ReluOp>(head2->getOutput(0));
63 vector<mir::Operation::Output *> concat_inputs{tail1->getOutput(0), tail2->getOutput(0)};
64 Operation *join = g.create<ops::ConcatOp>(concat_inputs, 0);
65 input->getOutput(0)->setName("input");
66 head1->getOutput(0)->setName("head1");
67 head2->getOutput(0)->setName("head2");
68 tail1->getOutput(0)->setName("tail2");
69 tail2->getOutput(0)->setName("tail2");
70 join->getOutput(0)->setName("join");
71
72 // Check that layout is desired
74 ma.analyze(&g);
75 const auto &seq = ma.getInferenceSequence();
76 ASSERT_EQ(seq.size(), 6u);
77
78 vector<Operation *> op_seq(seq.size());
79 transform(seq.cbegin(), seq.cend(), op_seq.begin(),
80 [](const unique_ptr<sir::Action> &action) { return getCall(action)->mirOp; });
81
82 vector<Operation *> valid_seq1{input, head1, tail1, head2, tail2, join};
83 vector<Operation *> valid_seq2{input, head2, tail2, head1, tail1, join};
84 ASSERT_TRUE(op_seq == valid_seq1 || op_seq == valid_seq2);
85}
Output * getOutput(std::size_t index)
Definition Operation.h:149
Description of tensor concatenation operation.
Definition ConcatOp.h:31
Constructs inference sequence for given computational graph, gathers list of variables used in artifa...
const std::vector< std::unique_ptr< sir::Action > > & getInferenceSequence() const
void analyze(const mir::Graph *g)
constructs the inference sequence
std::string join(const std::string &path1, const std::string &path2)
Definition Shape.h:28

References nnc::ModelAnalyzer::analyze(), nnc::ModelAnalyzer::getInferenceSequence(), and mir::Operation::getOutput().