ONE - On-device Neural Engine
Loading...
Searching...
No Matches
CircleCast.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
18
20
21#include <luci/UserSettings.h>
22#include <luci/Log.h>
23
24#include <loco.h>
25
26namespace luci
27{
28
// Body of the Cast-operator validation routine.
//
// NOTE(review): this listing jumps from original line 28 to 30, so the function
// signature (original line 29) is not visible here. The symbol index at the end
// of the listing shows `bool validate(const ValidateArgs &args) const final`,
// which is presumably this function -- confirm against the real file.
//
// Visible flow:
//  - returns false when GraphBuilder::validate(args, 1) fails;
//  - when the operator carries CastOptions, compares the type of the first input
//    tensor with options->in_data_type and the type of the first output tensor
//    with options->out_data_type;
//  - returns true when no CastOptions are present or no mismatch branch
//    returned false.
30{
31 LOGGER(l);
32
33 if (!GraphBuilder::validate(args, 1))
34 return false;
35
// NOTE(review): `settings` is not referenced by any line visible in this
// listing; it is presumably consumed by the missing original lines 54 and 64.
36 auto settings = luci::UserSettings::settings();
37
38 const auto &inputs = args.op.inputs;
39 const auto &outputs = args.op.outputs;
40
41 // NOTE real models do have type mismatch
42 const auto *options = args.op.builtin_options.AsCastOptions();
43 if (options != nullptr)
44 {
45 const auto tensors = args.reader.tensors();
// NOTE(review): this lookup uses operator[] while the later lookups use .at();
// whether outputs[0] is guaranteed to exist at this point depends on what
// GraphBuilder::validate(args, 1) checks -- confirm.
46 const auto output_tensor = tensors[outputs[0]];
47 assert(output_tensor != nullptr);
// The output tensor's name is only used in the warning messages below.
48 auto name = tensor_name(output_tensor);
49
50 const auto tensor_in = tensors.at(inputs.at(0));
51 assert(tensor_in != nullptr);
// Input side: the tensor's declared type must agree with options->in_data_type.
52 if (tensor_in->type() != options->in_data_type)
53 {
// NOTE(review): original line 54 is missing from this listing. It presumably
// holds the `if (...)` condition that selects between the warning block and the
// `else return false` below (likely a check on `settings`) -- confirm.
55 {
56 WARN(l) << "Warning: import Cast(" << name << ") dtype mismatch";
57 }
58 else
59 return false;
60 }
// Output side: same comparison against options->out_data_type.
61 const auto &tensor_out = tensors.at(outputs[0]);
62 if (tensor_out->type() != options->out_data_type)
63 {
// NOTE(review): original line 64 is missing from this listing; presumably the
// same condition as the missing line 54 -- confirm.
65 {
66 WARN(l) << "Warning: import Cast(" << name << ") dtype mismatch";
67 }
68 else
69 return false;
70 }
71 }
72
// Reached when CastOptions are absent, or no mismatch branch returned false.
73 return true;
74}
75
76CircleNode *CircleCastGraphBuilder::build_node(const circle::OperatorT &op,
77 const std::vector<CircleNode *> &inputs,
78 loco::Graph *graph) const
79{
80 auto *node = graph->nodes()->create<CircleCast>();
81 node->x(inputs.at(0));
82
83 const auto *options = op.builtin_options.AsCastOptions();
84 if (options != nullptr)
85 {
86 node->in_data_type(luci_datatype(options->in_data_type));
87 node->out_data_type(luci_datatype(options->out_data_type));
88 }
89 else
90 {
91 node->in_data_type(inputs.at(0)->dtype());
92 node->out_data_type(loco::DataType::Unknown);
93 // type inference should use node->dtype() for Unknown
94 // export should use BuiltinOptions_NONE for Unknown
95 }
96
97 return node;
98}
99
100} // namespace luci
#define LOGGER(name)
Definition Log.h:65
A neural network graph.
Definition Graph.h:161
bool validate(const ValidateArgs &args) const final
CAST in Circle.
Definition CircleCast.h:32
loco::Node * x(void) const
Definition CircleCast.h:34
bool validate(const ValidateArgs &args, size_t input_cnt) const
#define WARN(name)
Definition Log.h:70
loco::DataType luci_datatype(circle::TensorType type)
const char * tensor_name(const circle::Tensor *tensor)
static UserSettings * settings()