ONE - On-device Neural Engine
Loading...
Searching...
No Matches
TestHelper.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __MPQSOLVER_TEST_HELPER_H__
18#define __MPQSOLVER_TEST_HELPER_H__
19
20#include "DataProvider.h"
21
22#include <luci/IR/CircleNodes.h>
23#include <luci/IR/Module.h>
24#include <luci/test/TestIOGraph.h>
25
26namespace mpqsolver
27{
28namespace test
29{
30namespace models
31{
32
37{
38public:
39 SimpleGraph() : _g(loco::make_graph()) {}
40
41public:
42 void init();
43
44 virtual ~SimpleGraph() = default;
46
47protected:
49 virtual void initInput(loco::Node *){};
50
51public:
52 std::unique_ptr<loco::Graph> _g;
55 uint32_t _channel_size = 16;
56 uint32_t _width = 4;
57 uint32_t _height = 4;
58};
59
63class AddGraph final : public SimpleGraph
64{
65private:
66 void initInput(loco::Node *input) override;
67 void initMinMax(luci::CircleNode *node);
68
69 loco::Node *insertGraphBody(loco::Node *input) override;
70
71public:
72 float _a_min = -1.f;
73 float _a_max = 1.f;
76};
77
79{
80public:
81 SoftmaxGraphlet() = default;
82 virtual ~SoftmaxGraphlet() = default;
83
84 void init(loco::Graph *g);
85
86protected:
87 void initMinMax(luci::CircleNode *node, float min, float max);
88
89public:
96
97protected:
99};
100
101class SoftmaxTestGraph : public luci::test::TestIOGraph, public SoftmaxGraphlet
102{
103public:
104 SoftmaxTestGraph() = default;
105
106 void init(void);
107};
108
109} // namespace models
110
111namespace io_utils
112{
113
117void makeTemporaryFile(char *name_template);
118
122void writeDataToFile(const std::string &file_path, const std::string &data);
123
127std::string makeTemporaryFolder(char *name_template);
128
132bool isFileExists(const std::string &file_path);
133
134} // namespace io_utils
135
136namespace data_utils
137{
138
139std::unique_ptr<mpqsolver::core::DataProvider> getAllZeroSingleDataProvider();
140
141} // namespace data_utils
142
143} // namespace test
144} // namespace mpqsolver
145
146#endif //__MPQSOLVER_TEST_HELPER_H__
A neural network graph.
Definition Graph.h:161
Logical unit of computation.
Definition Node.h:54
ABS in Circle.
Definition CircleAbs.h:32
ADD in Circle.
Definition CircleAdd.h:34
Class to build tensor data.
Definition CircleConst.h:35
DIV in Circle.
Definition CircleDiv.h:37
EXP in Circle.
Definition CircleExp.h:32
CircleNode used for Input of the Graph.
Definition CircleInput.h:36
CircleNode for Output of the Graph.
REDUCE_MAX in Circle.
SUB in Circle.
Definition CircleSub.h:34
SUM in Circle.
Definition CircleSum.h:32
Collection of 'loco::Graph's.
Definition Module.h:33
Simple model consisting of just an Add of the input and a constant.
Definition TestHelper.h:64
Base class of simple graphs used for testing.
Definition TestHelper.h:37
std::unique_ptr< loco::Graph > _g
Definition TestHelper.h:52
virtual void initInput(loco::Node *)
Definition TestHelper.h:49
void transfer_to(luci::Module *module)
virtual loco::Node * insertGraphBody(loco::Node *input)=0
void initMinMax(luci::CircleNode *node, float min, float max)
std::unique_ptr< mpqsolver::core::DataProvider > getAllZeroSingleDataProvider()
std::string makeTemporaryFolder(char *name_template)
Create a valid name for a temporary folder.
void makeTemporaryFile(char *name_template)
Create a valid name for a temporary file.
bool isFileExists(const std::string &file_path)
Check whether the file at file_path exists.
void writeDataToFile(const std::string &file_path, const std::string &data)
Write data to the file at file_path.