20#include "../KernelGenerator.h"
21#include "../Validator.h"
26void Validator::visit(
const ir::operation::FullyConnected &node)
28 using ir::operation::FullyConnected;
30 const auto weight_index{node.getInputs().at(FullyConnected::Input::WEIGHT)};
35 if (weight_node->typeInfo().type() != ir::DataType::QUANT_GGML_Q4_0 &&
36 weight_node->typeInfo().type() != ir::DataType::QUANT_GGML_Q8_0)
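// For reference, the two data types accepted above are ggml's block-quantized weight
// formats. Their layouts (paraphrased from ggml's ggml-common.h, shown here as a
// comment rather than redeclared) pack 32 weights per block with one fp16 scale:
//
//   typedef struct {
//     ggml_fp16_t d;       // per-block scale
//     uint8_t qs[32 / 2];  // 32 x 4-bit values packed two per byte -> 18 bytes/block
//   } block_q4_0;
//
//   typedef struct {
//     ggml_fp16_t d;       // per-block scale
//     int8_t qs[32];       // 32 x 8-bit values -> 34 bytes/block
//   } block_q8_0;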
void KernelGenerator::visit(const ir::operation::FullyConnected &node)
{
  using ir::operation::FullyConnected;

  const auto output_index{node.getOutputs().at(0)};
  const auto input_index{node.getInputs().at(FullyConnected::Input::INPUT)};
  const auto weight_index{node.getInputs().at(FullyConnected::Input::WEIGHT)};
  const auto bias_index{node.getInputs().at(FullyConnected::Input::BIAS)};
  const auto activation = node.param().activation;
  const auto weights_format = node.param().weights_format;

  if (weights_format != ir::FullyConnectedWeightsFormat::Default)
    throw std::runtime_error("Unsupported FullyConnected Weights Format");

  auto output_tensor = _tensor_reg->getPortableTensor(output_index);
  auto input_tensor = _tensor_reg->getPortableTensor(input_index);
  auto weight_tensor = _tensor_reg->getPortableTensor(weight_index);
  auto bias_tensor = bias_index.undefined() ? nullptr : _tensor_reg->getPortableTensor(bias_index);

  auto fn = std::make_unique<ops::FullyConnectedLayer>();

  fn->configure(input_tensor, weight_tensor, bias_tensor, activation, output_tensor,
                _external_context);

  _return_fn = std::move(fn);
}
FullyConnectedLayer::FullyConnectedLayer()
  : _input(nullptr), _weights(nullptr), _bias(nullptr), _output(nullptr),
    _activation(ir::Activation::NONE), _external_context(nullptr)
{
  // ...
}
void FullyConnectedLayer::fullyConnectedGGMLWeight()
{
  if (_bias)
    throw std::runtime_error{"FullyConnected: GGML weights format does not support bias yet."};

  // Wrap the backend tensors as ggml tensors and describe the matmul as a single node
  auto input = getGGMLTensor(_input);
  auto weights = getGGMLTensor(_weights);
  auto output = getGGMLTensor(_output);
  output.op = GGML_OP_MUL_MAT;
  output.src[0] = &weights;
  output.src[1] = &input;
  auto *nodes = &output;

  // Assemble a one-node compute graph by hand instead of going through a ggml context
  struct ggml_cgraph graph;
  memset(&graph, 0, sizeof(graph));
  graph.n_nodes = 1;
  graph.nodes = &nodes;

  // Provide the scratch buffer the plan asks for, then execute the graph
  // ... (cplan comes from ggml_graph_plan(&graph, <n_threads>) in the elided lines) ...
  std::vector<uint8_t> buf(cplan.work_size);
  cplan.work_data = buf.data();
  ggml_graph_compute(&graph, &cplan);
}
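// For comparison only: a minimal sketch, not part of this backend, of the same matmul
// written against ggml's regular graph-building API instead of the hand-assembled
// single-node graph above. Shapes, context size and thread count are placeholders,
// some of these helpers differ between ggml revisions, and the sketch reuses the
// ggml.h / <vector> includes this file already relies on.
static void ggml_mul_mat_api_sketch()
{
  // Scratch context owning the tensor metadata and (uninitialized) data buffers.
  ggml_init_params params{/*mem_size=*/8 * 1024 * 1024, /*mem_buffer=*/nullptr, /*no_alloc=*/false};
  ggml_context *ctx = ggml_init(params);

  // weights: [K, N] in Q4_0, input: [K, M] in F32 -> output: [N, M] in F32
  ggml_tensor *weights = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 1024, 1024);
  ggml_tensor *input = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 1);
  ggml_tensor *out = ggml_mul_mat(ctx, weights, input);

  // Let ggml build the graph and size the scratch buffer, then execute.
  ggml_cgraph *graph = ggml_new_graph(ctx);
  ggml_build_forward_expand(graph, out);

  ggml_cplan cplan = ggml_graph_plan(graph, /*n_threads=*/4);
  std::vector<uint8_t> work(cplan.work_size);
  cplan.work_data = work.data();
  ggml_graph_compute(graph, &cplan);

  ggml_free(ctx);
}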
void FullyConnectedLayer::configure(const IPortableTensor *input, const IPortableTensor *weights,
                                    const IPortableTensor *bias, ir::Activation activation,
                                    IPortableTensor *output,
                                    const std::shared_ptr<ExternalContext> &external_context)
{ /* ... (bind the arguments to the members above) ... */ }

// ...
  throw std::runtime_error{"FullyConnected: unsupported data type"};