ONE - On-device Neural Engine
Loading...
Searching...
No Matches
GRU.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "kernels/GRU.h"

#include "kernels/Utils.h"

#include "PALFullyConnected.h"
#include "PALGRU.h"

#include <vector>
23
24namespace luci_interpreter
25{
26namespace kernels
27{
// Constructs the GRU kernel from six read-only input tensors (input, hidden_hidden,
// hidden_hidden_bias, hidden_input, hidden_input_bias, state), one writable output
// tensor and the GRUParams. The constructor body itself is empty.
// NOTE(review): lines 31-32 of the original GRU.cpp are missing from this listing; they
// presumably open a base-class/member initializer that ends with the `params)` line
// below. TODO: restore them from the original source before building.
28GRU::GRU(const Tensor *input, const Tensor *hidden_hidden, const Tensor *hidden_hidden_bias,
29 const Tensor *hidden_input, const Tensor *hidden_input_bias, const Tensor *state,
30 Tensor *output, const GRUParams &params)
33 params)
34{
35}
36
38{
39 auto hidden_hidden_shape = getTensorShape(hidden_hidden());
40 auto hidden_input_shape = getTensorShape(hidden_input());
41 LUCI_INTERPRETER_CHECK(hidden_hidden_shape.Dims(0) == hidden_input_shape.Dims(0));
42
43 output()->resize(state()->shape());
44
45 LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
46}
47
48void GRU::execute() const
49{
50 switch (input()->element_type())
51 {
52 case DataType::FLOAT32:
53 evalFloat();
54 break;
55 default:
56 throw std::runtime_error("luci-GRU Unsupported data type.");
57 }
58}
59
60void GRU::evalFloat() const
61{
62 uint8_t *output_hidden_data;
63 uint8_t *output_input_data;
64
65 // allocate output datas above
66 output_hidden_data = new uint8_t[getTensorShape(hidden_hidden()).FlatSize() * sizeof(float)];
67 output_input_data = new uint8_t[getTensorShape(hidden_input()).FlatSize() * sizeof(float)];
68
70 getTensorData<float>(input()), getTensorData<float>(hidden_input()),
71 getTensorData<float>(hidden_hidden()), getTensorData<float>(hidden_input_bias()),
72 getTensorData<float>(hidden_hidden_bias()), getTensorData<float>(state()),
73 getTensorData<float>(output()), reinterpret_cast<float *>(output_input_data),
74 reinterpret_cast<float *>(output_hidden_data), getTensorShape(input()),
76
77 delete output_hidden_data;
78 delete output_input_data;
79}
80
81} // namespace kernels
82} // namespace luci_interpreter
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
void configure() override
Definition GRU.cpp:37
const Tensor * input() const
Definition GRU.h:35
void execute() const override
Definition GRU.cpp:48
const Tensor * hidden_input_bias() const
Definition GRU.h:39
const Tensor * hidden_input() const
Definition GRU.h:38
Tensor * output() const
Definition GRU.h:41
const Tensor * hidden_hidden_bias() const
Definition GRU.h:37
const Tensor * hidden_hidden() const
Definition GRU.h:36
GRU(const Tensor *input, const Tensor *hidden_hidden, const Tensor *hidden_hidden_bias, const Tensor *hidden_input, const Tensor *hidden_input_bias, const Tensor *state, Tensor *output, const GRUParams &params)
Definition GRU.cpp:28
const Tensor * state() const
Definition GRU.h:40
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
void GRU(const float *input_data, const float *weight_input_data, const float *weight_hidden_data, const float *bias_input_data, const float *bias_hidden_data, const float *hidden_state_data, float *output_data, float *output_input_data, float *output_hidden_data, const tflite::RuntimeShape &input_shape, const tflite::RuntimeShape &output_shape, const tflite::RuntimeShape &weight_input_shape, const tflite::RuntimeShape &weight_hidden_shape)
Definition PALGRU.h:147