90 const circle::Tensor *input;
91 const circle::Tensor *weights_feature;
92 const circle::Tensor *weights_time;
93 const circle::Tensor *bias;
94 const circle::Tensor *activation_state;
96 const circle::Tensor *output;
99 uint8_t *weights_feature_data;
100 uint8_t *weights_time_data;
102 uint8_t *activation_state_data;
103 uint8_t *output_data;
104 const circle::SVDFOptions *options =
nullptr;
112 input = runtime_kernel.
inputs[inputTensorIdx];
113 weights_feature = runtime_kernel.
inputs[weightsFeatureTensorIdx];
114 weights_time = runtime_kernel.
inputs[weightsTimeTensorIdx];
115 bias = runtime_kernel.
inputs[biasTensorIdx];
116 activation_state = runtime_kernel.
inputs[inputActivationStateTensorIdx];
118 output = runtime_kernel.
outputs[outputTensorIdx];
120 assert(input !=
nullptr);
121 assert(weights_feature !=
nullptr);
122 assert(weights_time !=
nullptr);
124 assert(activation_state !=
nullptr);
125 assert(output !=
nullptr);
127 status = runtime_kernel.
getDataFromStorage(op_index, runtime_storage, runtime_context);
131 input_data = runtime_kernel.
inputs_data[inputTensorIdx];
132 weights_feature_data = runtime_kernel.
inputs_data[weightsFeatureTensorIdx];
133 weights_time_data = runtime_kernel.
inputs_data[weightsTimeTensorIdx];
134 bias_data = runtime_kernel.
inputs_data[biasTensorIdx];
135 activation_state_data = runtime_kernel.
inputs_data[inputActivationStateTensorIdx];
136 output_data = runtime_kernel.
outputs_data[outputTensorIdx];
138 assert(input_data !=
nullptr);
139 assert(weights_feature_data !=
nullptr);
140 assert(weights_time_data !=
nullptr);
142 assert(output_data !=
nullptr);
144 options = runtime_kernel.
first_operator->builtin_options_as_SVDFOptions();
155 const int rank = options->rank();
156 const int input_size = input_shape.
dims(1);
157 const int batch_size = input_shape.
dims(0);
158 const int num_filters = weights_feature_shape.
dims(0);
160 const int num_units = num_filters / rank;
161 const int memory_size = weights_time_shape.
dims(1);
163 const auto activation_state_size =
170 std::memset(activation_state_data, 0, activation_state_size);
172 switch (input->type())
175 case circle::TensorType_FLOAT32:
178 uint8_t *scratch_buffer;
180 batch_size * num_filters *
sizeof(
core::OMDataType(output->type())), &scratch_buffer);
182 assert(status ==
Ok);
186 utils::castInputData<float>(input_data), utils::castInputData<float>(weights_feature_data),
187 utils::castInputData<float>(weights_time_data), utils::castInputData<float>(bias_data),
188 utils::castOutputData<float>(activation_state_data),
189 utils::castOutputData<float>(scratch_buffer), utils::castOutputData<float>(output_data),
190 rank, input_size, batch_size, num_filters, num_units, memory_size,
191 options->fused_activation_function());
198 case circle::TensorType_INT8:
201 prepareQuantParams(params, input, weights_feature, weights_time, activation_state, output);
206 params, utils::castInputData<int8_t>(input_data),
207 utils::castInputData<int8_t>(weights_feature_data),
208 utils::castInputData<int8_t>(weights_time_data), utils::castInputData<int32_t>(bias_data),
209 utils::castOutputData<int8_t>(activation_state_data),
210 utils::castOutputData<int8_t>(output_data), input_shape, weights_feature_shape,
218 assert(
false &&
"Unsupported type.");
OMStatus SVDF(const core::SVDFQuantParams ¶ms, const int8_t *input_data, const int8_t *weights_feature_data, const int8_t *weights_time_data, const int32_t *bias_data, int8_t *state_data, int8_t *output_data, const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &weights_feature_shape, const core::OMRuntimeShape &weights_time_shape, const core::OMRuntimeShape &bias_shape, const core::OMRuntimeShape &output_shape)