ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci::CircleCastGraphBuilder Class Reference

#include <CircleCast.h>

Collaboration diagram for luci::CircleCastGraphBuilder:

Public Member Functions

bool validate (const ValidateArgs &args) const final
 
- Public Member Functions inherited from luci::GraphBuilder
virtual ~GraphBuilder ()=default
 
bool validate (const ValidateArgs &args, size_t input_cnt) const
 
CircleNode * build (const circle::OperatorT &op, GraphBuilderContext *context) const final
 
- Public Member Functions inherited from luci::GraphBuilderBase
virtual ~GraphBuilderBase ()=default
 

Detailed Description

Definition at line 25 of file CircleCast.h.

Member Function Documentation

◆ validate()

bool luci::CircleCastGraphBuilder::validate ( const ValidateArgs & args) const
final, virtual

Implements luci::GraphBuilderBase.

Definition at line 29 of file CircleCast.cpp.

30{
31 LOGGER(l);
32
33 if (!GraphBuilder::validate(args, 1))
34 return false;
35
36 auto settings = luci::UserSettings::settings();
37
38 const auto &inputs = args.op.inputs;
39 const auto &outputs = args.op.outputs;
40
41 // NOTE real models do have type mismatch
42 const auto *options = args.op.builtin_options.AsCastOptions();
43 if (options != nullptr)
44 {
45 const auto tensors = args.reader.tensors();
46 const auto output_tensor = tensors[outputs[0]];
47 assert(output_tensor != nullptr);
48 auto name = tensor_name(output_tensor);
49
50 const auto tensor_in = tensors.at(inputs.at(0));
51 assert(tensor_in != nullptr);
52 if (tensor_in->type() != options->in_data_type)
53 {
54 if (settings->get(luci::UserSettings::Key::DisableValidation))
55 {
56 WARN(l) << "Warning: import Cast(" << name << ") dtype mismatch";
57 }
58 else
59 return false;
60 }
61 const auto &tensor_out = tensors.at(outputs[0]);
62 if (tensor_out->type() != options->out_data_type)
63 {
64 if (settings->get(luci::UserSettings::Key::DisableValidation))
65 {
66 WARN(l) << "Warning: import Cast(" << name << ") dtype mismatch";
67 }
68 else
69 return false;
70 }
71 }
72
73 return true;
74}
#define LOGGER(name)
Definition Log.h:65
bool validate(const ValidateArgs &args, size_t input_cnt) const
#define WARN(name)
Definition Log.h:70
args
Definition infer.py:21
const char * tensor_name(const circle::Tensor *tensor)
static UserSettings * settings()

References luci::UserSettings::DisableValidation, LOGGER, luci::UserSettings::settings(), luci::tensor_name(), luci::GraphBuilder::validate(), and WARN.


The documentation for this class was generated from the following files: