19#include "kernels/Utils.h"
21#include "PALFullyConnected.h"
50 switch (
input()->element_type())
52 case DataType::FLOAT32:
56 throw std::runtime_error(
"luci-GRU Unsupported data type.");
60void GRU::evalFloat()
const
62 uint8_t *output_hidden_data;
63 uint8_t *output_input_data;
73 getTensorData<float>(
output()),
reinterpret_cast<float *
>(output_input_data),
77 delete output_hidden_data;
78 delete output_input_data;
void resize(const Shape &new_shape)
void configure() override
const Tensor * input() const
void execute() const override
const Tensor * hidden_input_bias() const
const Tensor * hidden_input() const
const Tensor * hidden_hidden_bias() const
const Tensor * hidden_hidden() const
GRU(const Tensor *input, const Tensor *hidden_hidden, const Tensor *hidden_hidden_bias, const Tensor *hidden_input, const Tensor *hidden_input_bias, const Tensor *state, Tensor *output, const GRUParams ¶ms)
const Tensor * state() const
#define LUCI_INTERPRETER_CHECK(cond)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void GRU(const float *input_data, const float *weight_input_data, const float *weight_hidden_data, const float *bias_input_data, const float *bias_hidden_data, const float *hidden_state_data, float *output_data, float *output_input_data, float *output_hidden_data, const tflite::RuntimeShape &input_shape, const tflite::RuntimeShape &output_shape, const tflite::RuntimeShape &weight_input_shape, const tflite::RuntimeShape &weight_hidden_shape)