ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter_pal::lstm_internal Namespace Reference

Data Structures

struct  LstmSizeInfo
 
class  LstmStepManager
 

Functions

template<typename InputType , typename OutputType >
void mulElementwise (int size, const ArithmeticParams *params, const InputType *input1_data, const InputType *input2_data, OutputType *output_data)
 
void mul (const luci_interpreter::RuntimeShape &shape, const ArithmeticParams *params, const int16_t *input1_data, const int16_t *input2_data, int8_t *output_data)
 
void mul (const luci_interpreter::RuntimeShape &shape, const ArithmeticParams *params, const int16_t *input1_data, const int16_t *input2_data, int16_t *output_data)
 
void addElementWise (const int16_t *input_1, const int16_t *input_2, int n_batch, int n_input, int16_t *output)
 
void tanh (int32_t cell_state_scale_power, const luci_interpreter::RuntimeShape &input_data_shape, int16_t *input_data, const luci_interpreter::RuntimeShape &output_data_shape, int16_t *output_data)
 
void sigmoid (const luci_interpreter::RuntimeShape &data_shape, int16_t *data)
 
void clipping (const int v_size, const luci_interpreter::lstm::CellStateInfo *cell_state_info, int16_t *vector)
 
void mul (const luci_interpreter::RuntimeShape &shape, const ArithmeticParams *params, const float *input1_data, const float *input2_data, float *output_data)
 
void addElementWise (const float *input_1, const float *input_2, int n_batch, int n_input, float *output)
 
void tanh (int32_t, const luci_interpreter::RuntimeShape &input_data_shape, float *input_data, const luci_interpreter::RuntimeShape &output_data_shape, float *output_data)
 
void sigmoid (const luci_interpreter::RuntimeShape &data_shape, float *data)
 
void clipping (const int v_size, const luci_interpreter::lstm::CellStateInfo *cell_state_info, float *vector)
 
template<typename ActivationType , typename WeightType , typename CellType , typename BiasType >
void calculateLstmGate (const LstmStepManager *step_info, const luci_interpreter::lstm::GateParameters *gate_params, ActivationType *input_data, const circle::Tensor *input_weight, const circle::Tensor *input_bias, ActivationType *recurrent_data, const circle::Tensor *recurrent_weight, const circle::Tensor *recurrent_bias, CellType *gate_output, CellType *fc_output_buffer, const FusedActivation activation, luci_interpreter::BaseRuntimeGraph *runtime_graph)
 
template<typename CellType , typename ActivationType >
void updateLstmHidden (const LstmStepManager *step_info, CellType *cell_state_data_base, ActivationType *hidden_state_data, const CellType *output_gate_output, const ArithmeticParams *mul_params, int32_t cell_state_scale_power, CellType *buffer)
 
template<typename CellType >
void updateLstmCell (const LstmStepManager *step_info, CellType *cell_state_data, CellType *forget_gate_output, const CellType *input_gate_output, const CellType *cell_gate_output, const ArithmeticParams &forget_cell_mul_params, const ArithmeticParams &input_mul_params, const luci_interpreter::lstm::CellStateInfo *cell_state_info, CellType *buffer)
 
template<typename ActivationType , typename WeightType , typename CellType , typename BiasType >
void lstmStep (luci_interpreter::lstm::LSTMStruct *lstm_struct, luci_interpreter::lstm::LSTMParameters *lstm_params, LstmStepManager *step_info, luci_interpreter::lstm::CellStateInfo *cell_state_info, ActivationType *output_state_data, CellType *cell_state_data, CellType *scratch0, CellType *scratch1, CellType *scratch2, CellType *scratch3, luci_interpreter::BaseRuntimeGraph *runtime_graph)
 

Function Documentation

◆ addElementWise() [1/2]

void luci_interpreter_pal::lstm_internal::addElementWise ( const float *  input_1,
const float *  input_2,
int  n_batch,
int  n_input,
float *  output 
)

Definition at line 142 of file PALUnidirectionalSequenceLSTMCommon.h.

144{
145 for (int batch = 0; batch < n_batch; ++batch)
146 {
147 for (int i = 0; i < n_input; ++i)
148 {
149 const int index = batch * n_input + i;
150 output[index] = input_1[index] + input_2[index];
151 }
152 }
153}

◆ addElementWise() [2/2]

void luci_interpreter_pal::lstm_internal::addElementWise ( const int16_t *  input_1,
const int16_t *  input_2,
int  n_batch,
int  n_input,
int16_t *  output 
)

Definition at line 84 of file PALUnidirectionalSequenceLSTMCommon.h.

86{
87 for (int batch = 0; batch < n_batch; ++batch)
88 {
89 for (int i = 0; i < n_input; ++i)
90 {
91 const int index = batch * n_input + i;
92 int32_t sum = input_1[index] + input_2[index];
93 const int32_t sum_clamped =
94 std::min(static_cast<int32_t>(std::numeric_limits<int16_t>::max()),
95 std::max(static_cast<int32_t>(std::numeric_limits<int16_t>::min()), sum));
96 output[index] = static_cast<int16_t>(sum_clamped);
97 }
98 }
99}

Referenced by calculateLstmGate(), and updateLstmCell().

◆ calculateLstmGate()

template<typename ActivationType , typename WeightType , typename CellType , typename BiasType >
void luci_interpreter_pal::lstm_internal::calculateLstmGate ( const LstmStepManager *  step_info,
const luci_interpreter::lstm::GateParameters *  gate_params,
ActivationType *  input_data,
const circle::Tensor *  input_weight,
const circle::Tensor *  input_bias,
ActivationType *  recurrent_data,
const circle::Tensor *  recurrent_weight,
const circle::Tensor *  recurrent_bias,
CellType *  gate_output,
CellType *  fc_output_buffer,
const FusedActivation  activation,
luci_interpreter::BaseRuntimeGraph *  runtime_graph 
)

Definition at line 278 of file PALUnidirectionalSequenceLSTMCommon.h.

291{
292 // Input FC
293 const auto gate_output_shape = step_info->stateShape();
294 {
295 FullyConnectedParams op_params{};
296 op_params.input_offset = gate_params->input_fc_params.input_offset;
297 op_params.weights_offset = gate_params->input_fc_params.weights_offset;
298 op_params.output_offset = gate_params->input_fc_params.output_offset;
299 op_params.output_multiplier = gate_params->input_fc_params.output_multiplier;
300 op_params.output_shift = gate_params->input_fc_params.output_shift;
301 op_params.quantized_activation_min = gate_params->input_fc_params.quantized_activation_min;
302 op_params.quantized_activation_max = gate_params->input_fc_params.quantized_activation_max;
303 op_params.float_activation_max = gate_params->input_fc_params.float_activation_max;
304 op_params.float_activation_min = gate_params->input_fc_params.float_activation_min;
305
306 int32_t input_weight_shape[luci_interpreter::kMaxSmallSize];
307 luci_interpreter::kernels::getTensorDims(input_weight, runtime_graph, input_weight_shape);
308
309 FullyConnected(op_params, step_info->inputShape().dimsData(),
310 input_data + step_info->inputOffset(), input_weight_shape,
311 luci_interpreter::kernels::getTensorData<WeightType>(
312 runtime_graph->getConstDataByTensor(input_weight)),
313 luci_interpreter::kernels::getTensorData<BiasType>(
314 runtime_graph->getConstDataByTensor(input_bias)),
315 gate_output_shape.dimsData(), gate_output, gate_output_shape.dimensionsCount(),
316 luci_interpreter::Tensor::num_dims(input_weight));
317 }
318
319 // Recurrent FC
320 {
321 FullyConnectedParams op_params{};
322 op_params.input_offset = gate_params->recurrent_fc_params.input_offset;
323 op_params.weights_offset = gate_params->recurrent_fc_params.weights_offset;
324 op_params.output_offset = gate_params->recurrent_fc_params.output_offset;
325 op_params.output_multiplier = gate_params->recurrent_fc_params.output_multiplier;
326 op_params.output_shift = gate_params->recurrent_fc_params.output_shift;
327 op_params.quantized_activation_min = gate_params->recurrent_fc_params.quantized_activation_min;
328 op_params.quantized_activation_max = gate_params->recurrent_fc_params.quantized_activation_max;
329 op_params.float_activation_max = gate_params->recurrent_fc_params.float_activation_max;
330 op_params.float_activation_min = gate_params->recurrent_fc_params.float_activation_min;
331
332 int32_t recurrent_weight_shape[luci_interpreter::kMaxSmallSize];
333 luci_interpreter::kernels::getTensorDims(recurrent_weight, runtime_graph,
334 recurrent_weight_shape);
335
336 FullyConnected(op_params, step_info->stateShape().dimsData(),
337 recurrent_data + step_info->hiddenStateOffset(), recurrent_weight_shape,
338 luci_interpreter::kernels::getTensorData<WeightType>(
339 runtime_graph->getConstDataByTensor(recurrent_weight)),
340 luci_interpreter::kernels::getTensorData<BiasType>(
341 runtime_graph->getConstDataByTensor(recurrent_bias)),
342 gate_output_shape.dimsData(), fc_output_buffer,
343 gate_output_shape.dimensionsCount(),
344 luci_interpreter::Tensor::num_dims(recurrent_weight));
345
346 addElementWise(gate_output, fc_output_buffer, /*n_batch=*/gate_output_shape.dimsData()[0],
347 /*n_state=*/gate_output_shape.dimsData()[1], gate_output);
348
349 switch (activation)
350 {
351 case FusedActivation::kTfLiteActSigmoid:
352 sigmoid(gate_output_shape, gate_output);
353 break;
354 case FusedActivation::kTfLiteActTanh:
355 {
356 // Set the scale power to -12 to avoid shift
357 tanh(/*cell_state_scale_power=*/-12, gate_output_shape, gate_output, gate_output_shape,
358 gate_output);
359 }
360 break;
361 default:
362 // Only Sigmoid or Tanh is used.
363 assert(false && "Only Sigmoid or Tanh is used");
364 }
365 }
366}
void FullyConnected(const float *input_data, const Dims< 4 > &input_dims, const float *weights_data, const Dims< 4 > &weights_dims, const float *bias_data, const Dims< 4 > &bias_dims, float *output_data, const Dims< 4 > &output_dims)
uint8_t * getConstDataByTensor(const circle::Tensor *raw_tensor)
void getTensorDims(const circle::Tensor *tensor, BaseRuntimeGraph *runtime_graph, int32_t *dims)
Definition Utils.h:121
void sigmoid(const luci_interpreter::RuntimeShape &data_shape, int16_t *data)
void tanh(int32_t cell_state_scale_power, const luci_interpreter::RuntimeShape &input_data_shape, int16_t *input_data, const luci_interpreter::RuntimeShape &output_data_shape, int16_t *output_data)
void addElementWise(const int16_t *input_1, const int16_t *input_2, int n_batch, int n_input, int16_t *output)
luci_interpreter_pal::FullyConnectedParams recurrent_fc_params
luci_interpreter_pal::FullyConnectedParams input_fc_params

References addElementWise(), luci_interpreter::RuntimeShape::dimsData(), luci_interpreter_pal::FullyConnectedParams::float_activation_max, luci_interpreter_pal::FullyConnectedParams::float_activation_min, luci_interpreter::RuntimeGraph::getConstDataByTensor(), luci_interpreter::kernels::getTensorDims(), luci_interpreter_pal::lstm_internal::LstmStepManager::hiddenStateOffset(), luci_interpreter::lstm::GateParameters::input_fc_params, luci_interpreter_pal::FullyConnectedParams::input_offset, luci_interpreter_pal::lstm_internal::LstmStepManager::inputOffset(), luci_interpreter_pal::lstm_internal::LstmStepManager::inputShape(), luci_interpreter_pal::FullyConnectedParams::output_multiplier, luci_interpreter_pal::FullyConnectedParams::output_offset, luci_interpreter_pal::FullyConnectedParams::output_shift, luci_interpreter_pal::FullyConnectedParams::quantized_activation_max, luci_interpreter_pal::FullyConnectedParams::quantized_activation_min, luci_interpreter::lstm::GateParameters::recurrent_fc_params, sigmoid(), luci_interpreter_pal::lstm_internal::LstmStepManager::stateShape(), tanh(), and luci_interpreter_pal::FullyConnectedParams::weights_offset.

◆ clipping() [1/2]

void luci_interpreter_pal::lstm_internal::clipping ( const int  v_size,
const luci_interpreter::lstm::CellStateInfo *  cell_state_info,
float *  vector 
)

Definition at line 168 of file PALUnidirectionalSequenceLSTMCommon.h.

170{
171 for (int i = 0; i < v_size; i++)
172 {
173 vector[i] =
174 std::max(std::min(cell_state_info->cell_clip, vector[i]), -cell_state_info->cell_clip);
175 }
176}

References luci_interpreter::lstm::CellStateInfo::cell_clip.

◆ clipping() [2/2]

void luci_interpreter_pal::lstm_internal::clipping ( const int  v_size,
const luci_interpreter::lstm::CellStateInfo *  cell_state_info,
int16_t *  vector 
)

Definition at line 122 of file PALUnidirectionalSequenceLSTMCommon.h.

124{
125 for (int i = 0; i < v_size; i++)
126 {
127 vector[i] = std::max(std::min(cell_state_info->quantized_cell_clip, vector[i]),
128 static_cast<int16_t>(-cell_state_info->quantized_cell_clip));
129 }
130}

References luci_interpreter::lstm::CellStateInfo::quantized_cell_clip.

Referenced by updateLstmCell().

◆ lstmStep()

template<typename ActivationType , typename WeightType , typename CellType , typename BiasType >
void luci_interpreter_pal::lstm_internal::lstmStep ( luci_interpreter::lstm::LSTMStruct *  lstm_struct,
luci_interpreter::lstm::LSTMParameters *  lstm_params,
LstmStepManager *  step_info,
luci_interpreter::lstm::CellStateInfo *  cell_state_info,
ActivationType *  output_state_data,
CellType *  cell_state_data,
CellType *  scratch0,
CellType *  scratch1,
CellType *  scratch2,
CellType *  scratch3,
luci_interpreter::BaseRuntimeGraph *  runtime_graph 
)

Definition at line 422 of file PALUnidirectionalSequenceLSTMCommon.h.

428{
429 /*Step1: Calculate gate outputs to prepare cell state update*/
430 CellType *gate_internal_buffer = scratch3;
431 CellType *forget_gate_output = scratch0;
432
433 auto input_data = luci_interpreter::kernels::getTensorData<ActivationType>(
434 runtime_graph->getDataByTensor(lstm_struct->input()));
435
436 calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
437 step_info, &lstm_params->forget_gate_parameters,
438 // Input FC
439 input_data, lstm_struct->input_to_forget_weights(), lstm_struct->forget_gate_bias(),
440 // Recurrent FC
441 output_state_data, lstm_struct->recurrent_to_forget_weights(), nullptr,
442 // Output
443 forget_gate_output, gate_internal_buffer, FusedActivation::kTfLiteActSigmoid, runtime_graph);
444
445 // Input Gate calculation;
446 CellType *input_gate_output = scratch1;
447 calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
448 step_info, &lstm_params->input_gate_parameters,
449 // Input FC
450 input_data, lstm_struct->input_to_input_weights(), lstm_struct->input_gate_bias(),
451 // Recurrent FC
452 output_state_data, lstm_struct->recurrent_to_input_weights(),
453 /*recurrent_bias*/ nullptr,
454 // Output
455 input_gate_output,
456 // Scratch arrays
457 gate_internal_buffer, FusedActivation::kTfLiteActSigmoid, runtime_graph);
458
459 // Cell Gate calculation
460 CellType *cell_gate_output = scratch2;
461 calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
462 step_info, &lstm_params->cell_gate_parameters,
463 // Input FC
464 input_data, lstm_struct->input_to_cell_weights(), lstm_struct->cell_gate_bias(),
465 // Recurrent FC
466 output_state_data, lstm_struct->recurrent_to_cell_weights(),
467 /*recurrent_bias*/ nullptr,
468 // Output
469 cell_gate_output,
470 // Scratch arrays
471 gate_internal_buffer, FusedActivation::kTfLiteActTanh, runtime_graph);
472
473 /*Step2: update the cell state */
474 {
475 // const InterGateParameters& inter_gate_params = op_data.inter_gate_parameters;
476 CellType *updated_input_buffer = scratch1; // reuse buffer
477
478 updateLstmCell<CellType>(
479 step_info, cell_state_data, forget_gate_output, input_gate_output, cell_gate_output,
480 lstm_params->inter_gate_parameters.forget_cell_mul_params,
481 lstm_params->inter_gate_parameters.input_mul_params, cell_state_info, updated_input_buffer);
482 }
483
484 {
485 /*Step3: update the hidden state */
486 CellType *output_gate_output = scratch1; // reuse buffer
487 calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
488 step_info, &lstm_params->output_gate_parameters,
489 // Input FC
490 input_data, lstm_struct->input_to_output_weights(), lstm_struct->output_gate_bias(),
491 // Recurrent FC
492 output_state_data, lstm_struct->recurrent_to_output_weights(), nullptr,
493 // Output
494 output_gate_output,
495 // Scratch arrays
496 gate_internal_buffer, FusedActivation::kTfLiteActSigmoid, runtime_graph);
497 CellType *tanh_activated_cell_buffer = scratch0; // reuse buffer
498 updateLstmHidden<CellType, ActivationType>(
499 step_info, cell_state_data, output_state_data, output_gate_output,
500 &lstm_params->inter_gate_parameters.output_mul_params,
501 cell_state_info->cell_state_scale_power, tanh_activated_cell_buffer);
502
503 ActivationType *output_ptr = luci_interpreter::kernels::getTensorData<ActivationType>(
504 runtime_graph->getDataByTensor(lstm_struct->output()));
505 std::memcpy(output_ptr + step_info->outputOffset(),
506 output_state_data + step_info->hiddenStateOffset(),
507 step_info->stateShape().flatSize() * sizeof(ActivationType));
508 }
509}
uint8_t * getDataByTensor(const circle::Tensor *raw_tensor)
luci_interpreter_pal::ArithmeticParams output_mul_params
luci_interpreter_pal::ArithmeticParams forget_cell_mul_params
luci_interpreter_pal::ArithmeticParams input_mul_params

References luci_interpreter::lstm::LSTMStruct::cell_gate_bias(), luci_interpreter::lstm::LSTMParameters::cell_gate_parameters, luci_interpreter::lstm::CellStateInfo::cell_state_scale_power, luci_interpreter::RuntimeShape::flatSize(), luci_interpreter::lstm::InterGateParameters::forget_cell_mul_params, luci_interpreter::lstm::LSTMStruct::forget_gate_bias(), luci_interpreter::lstm::LSTMParameters::forget_gate_parameters, luci_interpreter::RuntimeGraph::getDataByTensor(), luci_interpreter_pal::lstm_internal::LstmStepManager::hiddenStateOffset(), luci_interpreter::lstm::LSTMStruct::input(), luci_interpreter::lstm::LSTMStruct::input_gate_bias(), luci_interpreter::lstm::LSTMParameters::input_gate_parameters, luci_interpreter::lstm::InterGateParameters::input_mul_params, luci_interpreter::lstm::LSTMStruct::input_to_cell_weights(), luci_interpreter::lstm::LSTMStruct::input_to_forget_weights(), luci_interpreter::lstm::LSTMStruct::input_to_input_weights(), luci_interpreter::lstm::LSTMStruct::input_to_output_weights(), luci_interpreter::lstm::LSTMParameters::inter_gate_parameters, luci_interpreter::lstm::LSTMStruct::output(), luci_interpreter::lstm::LSTMStruct::output_gate_bias(), luci_interpreter::lstm::LSTMParameters::output_gate_parameters, luci_interpreter::lstm::InterGateParameters::output_mul_params, luci_interpreter_pal::lstm_internal::LstmStepManager::outputOffset(), luci_interpreter::lstm::LSTMStruct::recurrent_to_cell_weights(), luci_interpreter::lstm::LSTMStruct::recurrent_to_forget_weights(), luci_interpreter::lstm::LSTMStruct::recurrent_to_input_weights(), luci_interpreter::lstm::LSTMStruct::recurrent_to_output_weights(), and luci_interpreter_pal::lstm_internal::LstmStepManager::stateShape().

◆ mul() [1/3]

void luci_interpreter_pal::lstm_internal::mul ( const luci_interpreter::RuntimeShape &  shape,
const ArithmeticParams *  params,
const float *  input1_data,
const float *  input2_data,
float *  output_data 
)

Definition at line 135 of file PALUnidirectionalSequenceLSTMCommon.h.

137{
138 const int flat_size = shape.flatSize();
139 return luci_interpreter_pal::Mul(*params, flat_size, input1_data, input2_data, output_data);
140}

References luci_interpreter::RuntimeShape::flatSize().

◆ mul() [2/3]

void luci_interpreter_pal::lstm_internal::mul ( const luci_interpreter::RuntimeShape &  shape,
const ArithmeticParams *  params,
const int16_t *  input1_data,
const int16_t *  input2_data,
int16_t *  output_data 
)

Definition at line 78 of file PALUnidirectionalSequenceLSTMCommon.h.

80{
81 return mulElementwise(shape.flatSize(), params, input1_data, input2_data, output_data);
82}
void mulElementwise(int size, const ArithmeticParams *params, const InputType *input1_data, const InputType *input2_data, OutputType *output_data)

References luci_interpreter::RuntimeShape::flatSize(), and mulElementwise().

◆ mul() [3/3]

void luci_interpreter_pal::lstm_internal::mul ( const luci_interpreter::RuntimeShape &  shape,
const ArithmeticParams *  params,
const int16_t *  input1_data,
const int16_t *  input2_data,
int8_t *  output_data 
)

Definition at line 70 of file PALUnidirectionalSequenceLSTMCommon.h.

72{
73 return mulElementwise<int16_t, int8_t>(shape.flatSize(), params, input1_data, input2_data,
74 output_data);
75}

References luci_interpreter::RuntimeShape::flatSize().

Referenced by updateLstmCell(), and updateLstmHidden().

◆ mulElementwise()

template<typename InputType , typename OutputType >
void luci_interpreter_pal::lstm_internal::mulElementwise ( int  size,
const ArithmeticParams *  params,
const InputType *  input1_data,
const InputType *  input2_data,
OutputType *  output_data 
)

Definition at line 51 of file PALUnidirectionalSequenceLSTMCommon.h.

53{
54 for (int i = 0; i < size; ++i)
55 {
56 const int32_t input1_val = params->input1_offset + input1_data[i];
57 const int32_t input2_val = params->input2_offset + input2_data[i];
58 const int32_t unclamped_result =
59 params->output_offset + multiplyByQuantizedMultiplier(input1_val * input2_val,
60 params->output_multiplier,
61 params->output_shift);
62 const int32_t clamped_output =
63 std::min(params->quantized_activation_max,
64 std::max(params->quantized_activation_min, unclamped_result));
65 output_data[i] = static_cast<OutputType>(clamped_output);
66 }
67}
int32_t multiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift)
Definition PALUtils.h:77
int32_t size[5]
Definition Slice.cpp:35

References luci_interpreter_pal::ArithmeticParams::input1_offset, luci_interpreter_pal::ArithmeticParams::input2_offset, luci_interpreter_pal::multiplyByQuantizedMultiplier(), luci_interpreter_pal::ArithmeticParams::output_multiplier, luci_interpreter_pal::ArithmeticParams::output_offset, luci_interpreter_pal::ArithmeticParams::output_shift, luci_interpreter_pal::ArithmeticParams::quantized_activation_max, luci_interpreter_pal::ArithmeticParams::quantized_activation_min, and size.

Referenced by mul().

◆ sigmoid() [1/2]

void luci_interpreter_pal::lstm_internal::sigmoid ( const luci_interpreter::RuntimeShape &  data_shape,
float *  data 
)

Definition at line 162 of file PALUnidirectionalSequenceLSTMCommon.h.

163{
164 const int flat_size = data_shape.flatSize();
165 luci_interpreter_pal::Logistic(flat_size, data, data);
166}
void Logistic(const int flat_size, const float *input_data, float *output_data)
Definition PALGRU.h:26

References luci_interpreter::RuntimeShape::flatSize(), and luci_interpreter_pal::Logistic().

◆ sigmoid() [2/2]

void luci_interpreter_pal::lstm_internal::sigmoid ( const luci_interpreter::RuntimeShape &  data_shape,
int16_t *  data 
)

Definition at line 117 of file PALUnidirectionalSequenceLSTMCommon.h.

118{
119 luci_interpreter_pal::Logistic(0, 0, data_shape.flatSize(), data, data);
120}

References luci_interpreter::RuntimeShape::flatSize(), and luci_interpreter_pal::Logistic().

Referenced by calculateLstmGate().

◆ tanh() [1/2]

void luci_interpreter_pal::lstm_internal::tanh ( int32_t  cell_state_scale_power,
const luci_interpreter::RuntimeShape &  input_data_shape,
int16_t *  input_data,
const luci_interpreter::RuntimeShape &  output_data_shape,
int16_t *  output_data 
)

Definition at line 101 of file PALUnidirectionalSequenceLSTMCommon.h.

104{
105 int32_t tanh_input_left_shift = (15 + cell_state_scale_power) - 3;
106 int32_t input_multiplier = 0;
107 if (tanh_input_left_shift < 0) /* handling negative shift value */
108 {
109 tanh_input_left_shift = -tanh_input_left_shift;
110 input_multiplier = 3;
111 }
112 const int flat_size = input_data_shape.flatSize();
113 luci_interpreter_pal::Tanh(input_multiplier, tanh_input_left_shift, flat_size, input_data,
114 output_data);
115}
void Tanh(const int flat_size, const float *input_data, float *output_data)
Definition PALTanh.h:26

References luci_interpreter::RuntimeShape::flatSize(), and luci_interpreter_pal::Tanh().

Referenced by calculateLstmGate(), and updateLstmHidden().

◆ tanh() [2/2]

void luci_interpreter_pal::lstm_internal::tanh ( int32_t  ,
const luci_interpreter::RuntimeShape &  input_data_shape,
float *  input_data,
const luci_interpreter::RuntimeShape &  output_data_shape,
float *  output_data 
)

Definition at line 155 of file PALUnidirectionalSequenceLSTMCommon.h.

157{
158 const int flat_size = input_data_shape.flatSize();
159 luci_interpreter_pal::Tanh(flat_size, input_data, output_data);
160}

References luci_interpreter::RuntimeShape::flatSize(), and luci_interpreter_pal::Tanh().

◆ updateLstmCell()

template<typename CellType >
void luci_interpreter_pal::lstm_internal::updateLstmCell ( const LstmStepManager *  step_info,
CellType *  cell_state_data,
CellType *  forget_gate_output,
const CellType *  input_gate_output,
const CellType *  cell_gate_output,
const ArithmeticParams &  forget_cell_mul_params,
const ArithmeticParams &  input_mul_params,
const luci_interpreter::lstm::CellStateInfo *  cell_state_info,
CellType *  buffer 
)

Definition at line 391 of file PALUnidirectionalSequenceLSTMCommon.h.

399{
400 auto cell_state_shape = step_info->stateShape();
401 // Forget Gate x Cell State
402 mul(cell_state_shape, &forget_cell_mul_params, forget_gate_output,
403 cell_state_data + step_info->cellStateOffset(),
404 cell_state_data + step_info->cellStateOffset());
405 // Input Gate x Cell Gate
406 mul(cell_state_shape, &input_mul_params, input_gate_output, cell_gate_output, buffer);
407
408 // Update the cell state
409 addElementWise(cell_state_data + step_info->cellStateOffset(), buffer,
410 /*n_batch=*/cell_state_shape.dimsData()[0],
411 /*n_state=*/cell_state_shape.dimsData()[1],
412 cell_state_data + step_info->cellStateOffset());
413
414 if (cell_state_info->cell_clip > 0)
415 {
416 clipping(cell_state_shape.flatSize(), cell_state_info,
417 cell_state_data + step_info->cellStateOffset());
418 }
419}
void mul(const luci_interpreter::RuntimeShape &shape, const ArithmeticParams *params, const int16_t *input1_data, const int16_t *input2_data, int8_t *output_data)
void clipping(const int v_size, const luci_interpreter::lstm::CellStateInfo *cell_state_info, int16_t *vector)

References addElementWise(), luci_interpreter::lstm::CellStateInfo::cell_clip, luci_interpreter_pal::lstm_internal::LstmStepManager::cellStateOffset(), clipping(), mul(), and luci_interpreter_pal::lstm_internal::LstmStepManager::stateShape().

◆ updateLstmHidden()

template<typename CellType , typename ActivationType >
void luci_interpreter_pal::lstm_internal::updateLstmHidden ( const LstmStepManager *  step_info,
CellType *  cell_state_data_base,
ActivationType *  hidden_state_data,
const CellType *  output_gate_output,
const ArithmeticParams *  mul_params,
int32_t  cell_state_scale_power,
CellType *  buffer 
)

Definition at line 372 of file PALUnidirectionalSequenceLSTMCommon.h.

376{
377 auto cell_state_shape = step_info->stateShape();
378 CellType *cell_state_data = cell_state_data_base + step_info->cellStateOffset();
379 // Tanh(cell_state)
380 tanh(cell_state_scale_power, cell_state_shape, cell_state_data, cell_state_shape, buffer);
381 // Update the hidden state
382 mul(cell_state_shape, mul_params, buffer, output_gate_output,
383 hidden_state_data + step_info->hiddenStateOffset());
384}

References luci_interpreter_pal::lstm_internal::LstmStepManager::cellStateOffset(), luci_interpreter_pal::lstm_internal::LstmStepManager::hiddenStateOffset(), mul(), luci_interpreter_pal::lstm_internal::LstmStepManager::stateShape(), and tanh().