46 const auto input_index = cur_op->inputs()->operator[](kSvdfInputTensor);
47 const auto weights_feature_index = cur_op->inputs()->operator[](kSvdfWeightsFeatureTensor);
48 const auto weights_time_index = cur_op->inputs()->operator[](kSvdfWeightsTimeTensor);
49 const auto bias_index = cur_op->inputs()->operator[](kSvdfBiasTensor);
50 const auto activation_state_index = cur_op->inputs()->operator[](kSvdfInputActivationStateTensor);
51 const auto output_index = cur_op->outputs()->operator[](kSvdfOutputTensor);
53 assert(input_index != -1);
54 assert(weights_feature_index != -1);
55 assert(weights_time_index != -1);
56 assert(activation_state_index != -1);
57 assert(output_index != -1);
66 assert(input !=
nullptr);
67 assert(weights_feature !=
nullptr);
68 assert(weights_time !=
nullptr);
69 assert(activation_state !=
nullptr);
70 assert(output !=
nullptr);
72 const auto *options = cur_op->builtin_options_as_SVDFOptions();
75 const int rank = options->rank();
78 const int num_filters =
Tensor::dim(weights_feature, 0);
81 const int num_units = num_filters / rank;
82 const int memory_size =
Tensor::dim(weights_time, 1);
85 Tensor::element_type(input) == DataType::S8);
114 if (Tensor::element_type(input) == DataType::FLOAT32)
127 const auto input_index = cur_op->inputs()->operator[](kSvdfInputTensor);
128 const auto weights_feature_index = cur_op->inputs()->operator[](kSvdfWeightsFeatureTensor);
129 const auto weights_time_index = cur_op->inputs()->operator[](kSvdfWeightsTimeTensor);
130 const auto bias_index = cur_op->inputs()->operator[](kSvdfBiasTensor);
131 const auto activation_state_index = cur_op->inputs()->operator[](kSvdfInputActivationStateTensor);
132 const auto output_index = cur_op->outputs()->operator[](kSvdfOutputTensor);
134 assert(input_index != -1);
135 assert(weights_feature_index != -1);
136 assert(weights_time_index != -1);
137 assert(activation_state_index != -1);
138 assert(output_index != -1);
147 assert(input !=
nullptr);
148 assert(weights_feature !=
nullptr);
149 assert(weights_time !=
nullptr);
150 assert(activation_state !=
nullptr);
151 assert(output !=
nullptr);
153 const auto *options = cur_op->builtin_options_as_SVDFOptions();
156 const int rank = options->rank();
159 const int num_filters =
Tensor::dim(weights_feature, 0);
162 const int num_units = num_filters / rank;
163 const int memory_size =
Tensor::dim(weights_time, 1);
171 const auto type = Tensor::element_type(input);
175 case DataType::FLOAT32:
178 auto state_data = std::make_unique<float[]>(Tensor::num_elements(activation_state));
179 std::fill_n(state_data.get(), Tensor::num_elements(activation_state), 0);
181 auto scratch_data = std::make_unique<uint8_t[]>(batch_size * num_filters *
sizeof(
float));
184 kernels::getTensorData<float>(input_data),
185 kernels::getTensorData<float>(weights_feature_data),
186 kernels::getTensorData<float>(weights_time_data), kernels::getTensorData<float>(bias_data),
187 state_data.get(), kernels::getTensorData<float>(scratch_data.get()),
188 kernels::getTensorData<float>(output_data), rank, input_size, batch_size, num_filters,
189 num_units, memory_size, options->fused_activation_function());
194 assert(
false &&
"Unsupported type.");
void SVDF(const float *input_data, const float *weights_feature_data, const float *weights_time_data, const float *bias_data, float *state_data, float *scratch_data, float *output_data, const int rank, const int input_size, const int batch_size, const int num_filters, const int num_units, const int memory_size, const circle::ActivationFunctionType activation)