ONE - On-device Neural Engine
Data Structures
  struct LstmSizeInfo
  class LstmStepManager
Functions
  template<typename InputType, typename OutputType>
  void mulElementwise(int size, const ArithmeticParams *params, const InputType *input1_data, const InputType *input2_data, OutputType *output_data)
  void mul(const luci_interpreter::RuntimeShape &shape, const ArithmeticParams *params, const int16_t *input1_data, const int16_t *input2_data, int8_t *output_data)
  void mul(const luci_interpreter::RuntimeShape &shape, const ArithmeticParams *params, const int16_t *input1_data, const int16_t *input2_data, int16_t *output_data)
  void addElementWise(const int16_t *input_1, const int16_t *input_2, int n_batch, int n_input, int16_t *output)
  void tanh(int32_t cell_state_scale_power, const luci_interpreter::RuntimeShape &input_data_shape, int16_t *input_data, const luci_interpreter::RuntimeShape &output_data_shape, int16_t *output_data)
  void sigmoid(const luci_interpreter::RuntimeShape &data_shape, int16_t *data)
  void clipping(const int v_size, const luci_interpreter::lstm::CellStateInfo *cell_state_info, int16_t *vector)
  void mul(const luci_interpreter::RuntimeShape &shape, const ArithmeticParams *params, const float *input1_data, const float *input2_data, float *output_data)
  void addElementWise(const float *input_1, const float *input_2, int n_batch, int n_input, float *output)
  void tanh(int32_t, const luci_interpreter::RuntimeShape &input_data_shape, float *input_data, const luci_interpreter::RuntimeShape &output_data_shape, float *output_data)
  void sigmoid(const luci_interpreter::RuntimeShape &data_shape, float *data)
  void clipping(const int v_size, const luci_interpreter::lstm::CellStateInfo *cell_state_info, float *vector)
  template<typename ActivationType, typename WeightType, typename CellType, typename BiasType>
  void calculateLstmGate(const LstmStepManager *step_info, const luci_interpreter::lstm::GateParameters *gate_params, ActivationType *input_data, const circle::Tensor *input_weight, const circle::Tensor *input_bias, ActivationType *recurrent_data, const circle::Tensor *recurrent_weight, const circle::Tensor *recurrent_bias, CellType *gate_output, CellType *fc_output_buffer, const FusedActivation activation, luci_interpreter::BaseRuntimeGraph *runtime_graph)
  template<typename CellType, typename ActivationType>
  void updateLstmHidden(const LstmStepManager *step_info, CellType *cell_state_data_base, ActivationType *hidden_state_data, const CellType *output_gate_output, const ArithmeticParams *mul_params, int32_t cell_state_scale_power, CellType *buffer)
  template<typename CellType>
  void updateLstmCell(const LstmStepManager *step_info, CellType *cell_state_data, CellType *forget_gate_output, const CellType *input_gate_output, const CellType *cell_gate_output, const ArithmeticParams &forget_cell_mul_params, const ArithmeticParams &input_mul_params, const luci_interpreter::lstm::CellStateInfo *cell_state_info, CellType *buffer)
  template<typename ActivationType, typename WeightType, typename CellType, typename BiasType>
  void lstmStep(luci_interpreter::lstm::LSTMStruct *lstm_struct, luci_interpreter::lstm::LSTMParameters *lstm_params, LstmStepManager *step_info, luci_interpreter::lstm::CellStateInfo *cell_state_info, ActivationType *output_state_data, CellType *cell_state_data, CellType *scratch0, CellType *scratch1, CellType *scratch2, CellType *scratch3, luci_interpreter::BaseRuntimeGraph *runtime_graph)
void luci_interpreter_pal::lstm_internal::addElementWise(
    const float *input_1,
    const float *input_2,
    int n_batch,
    int n_input,
    float *output)
Definition at line 142 of file PALUnidirectionalSequenceLSTMCommon.h.
void luci_interpreter_pal::lstm_internal::addElementWise(
    const int16_t *input_1,
    const int16_t *input_2,
    int n_batch,
    int n_input,
    int16_t *output)
Definition at line 84 of file PALUnidirectionalSequenceLSTMCommon.h.
Referenced by calculateLstmGate(), and updateLstmCell().
template<typename ActivationType, typename WeightType, typename CellType, typename BiasType>
void luci_interpreter_pal::lstm_internal::calculateLstmGate(
    const LstmStepManager *step_info,
    const luci_interpreter::lstm::GateParameters *gate_params,
    ActivationType *input_data,
    const circle::Tensor *input_weight,
    const circle::Tensor *input_bias,
    ActivationType *recurrent_data,
    const circle::Tensor *recurrent_weight,
    const circle::Tensor *recurrent_bias,
    CellType *gate_output,
    CellType *fc_output_buffer,
    const FusedActivation activation,
    luci_interpreter::BaseRuntimeGraph *runtime_graph)
Definition at line 278 of file PALUnidirectionalSequenceLSTMCommon.h.
References addElementWise(), luci_interpreter::RuntimeShape::dimsData(), luci_interpreter_pal::FullyConnectedParams::float_activation_max, luci_interpreter_pal::FullyConnectedParams::float_activation_min, luci_interpreter::RuntimeGraph::getConstDataByTensor(), luci_interpreter::kernels::getTensorDims(), luci_interpreter_pal::lstm_internal::LstmStepManager::hiddenStateOffset(), luci_interpreter::lstm::GateParameters::input_fc_params, luci_interpreter_pal::FullyConnectedParams::input_offset, luci_interpreter_pal::lstm_internal::LstmStepManager::inputOffset(), luci_interpreter_pal::lstm_internal::LstmStepManager::inputShape(), luci_interpreter_pal::FullyConnectedParams::output_multiplier, luci_interpreter_pal::FullyConnectedParams::output_offset, luci_interpreter_pal::FullyConnectedParams::output_shift, luci_interpreter_pal::FullyConnectedParams::quantized_activation_max, luci_interpreter_pal::FullyConnectedParams::quantized_activation_min, luci_interpreter::lstm::GateParameters::recurrent_fc_params, sigmoid(), luci_interpreter_pal::lstm_internal::LstmStepManager::stateShape(), tanh(), and luci_interpreter_pal::FullyConnectedParams::weights_offset.
void luci_interpreter_pal::lstm_internal::clipping(
    const int v_size,
    const luci_interpreter::lstm::CellStateInfo *cell_state_info,
    float *vector)
Definition at line 168 of file PALUnidirectionalSequenceLSTMCommon.h.
References luci_interpreter::lstm::CellStateInfo::cell_clip.
void luci_interpreter_pal::lstm_internal::clipping(
    const int v_size,
    const luci_interpreter::lstm::CellStateInfo *cell_state_info,
    int16_t *vector)
Definition at line 122 of file PALUnidirectionalSequenceLSTMCommon.h.
References luci_interpreter::lstm::CellStateInfo::quantized_cell_clip.
Referenced by updateLstmCell().
template<typename ActivationType, typename WeightType, typename CellType, typename BiasType>
void luci_interpreter_pal::lstm_internal::lstmStep(
    luci_interpreter::lstm::LSTMStruct *lstm_struct,
    luci_interpreter::lstm::LSTMParameters *lstm_params,
    LstmStepManager *step_info,
    luci_interpreter::lstm::CellStateInfo *cell_state_info,
    ActivationType *output_state_data,
    CellType *cell_state_data,
    CellType *scratch0,
    CellType *scratch1,
    CellType *scratch2,
    CellType *scratch3,
    luci_interpreter::BaseRuntimeGraph *runtime_graph)
Definition at line 422 of file PALUnidirectionalSequenceLSTMCommon.h.
References luci_interpreter::lstm::LSTMStruct::cell_gate_bias(), luci_interpreter::lstm::LSTMParameters::cell_gate_parameters, luci_interpreter::lstm::CellStateInfo::cell_state_scale_power, luci_interpreter::RuntimeShape::flatSize(), luci_interpreter::lstm::InterGateParameters::forget_cell_mul_params, luci_interpreter::lstm::LSTMStruct::forget_gate_bias(), luci_interpreter::lstm::LSTMParameters::forget_gate_parameters, luci_interpreter::RuntimeGraph::getDataByTensor(), luci_interpreter_pal::lstm_internal::LstmStepManager::hiddenStateOffset(), luci_interpreter::lstm::LSTMStruct::input(), luci_interpreter::lstm::LSTMStruct::input_gate_bias(), luci_interpreter::lstm::LSTMParameters::input_gate_parameters, luci_interpreter::lstm::InterGateParameters::input_mul_params, luci_interpreter::lstm::LSTMStruct::input_to_cell_weights(), luci_interpreter::lstm::LSTMStruct::input_to_forget_weights(), luci_interpreter::lstm::LSTMStruct::input_to_input_weights(), luci_interpreter::lstm::LSTMStruct::input_to_output_weights(), luci_interpreter::lstm::LSTMParameters::inter_gate_parameters, luci_interpreter::lstm::LSTMStruct::output(), luci_interpreter::lstm::LSTMStruct::output_gate_bias(), luci_interpreter::lstm::LSTMParameters::output_gate_parameters, luci_interpreter::lstm::InterGateParameters::output_mul_params, luci_interpreter_pal::lstm_internal::LstmStepManager::outputOffset(), luci_interpreter::lstm::LSTMStruct::recurrent_to_cell_weights(), luci_interpreter::lstm::LSTMStruct::recurrent_to_forget_weights(), luci_interpreter::lstm::LSTMStruct::recurrent_to_input_weights(), luci_interpreter::lstm::LSTMStruct::recurrent_to_output_weights(), and luci_interpreter_pal::lstm_internal::LstmStepManager::stateShape().
void luci_interpreter_pal::lstm_internal::mul(
    const luci_interpreter::RuntimeShape &shape,
    const ArithmeticParams *params,
    const float *input1_data,
    const float *input2_data,
    float *output_data)
Definition at line 135 of file PALUnidirectionalSequenceLSTMCommon.h.
References luci_interpreter::RuntimeShape::flatSize().
void luci_interpreter_pal::lstm_internal::mul(
    const luci_interpreter::RuntimeShape &shape,
    const ArithmeticParams *params,
    const int16_t *input1_data,
    const int16_t *input2_data,
    int16_t *output_data)
Definition at line 78 of file PALUnidirectionalSequenceLSTMCommon.h.
References luci_interpreter::RuntimeShape::flatSize(), and mulElementwise().
void luci_interpreter_pal::lstm_internal::mul(
    const luci_interpreter::RuntimeShape &shape,
    const ArithmeticParams *params,
    const int16_t *input1_data,
    const int16_t *input2_data,
    int8_t *output_data)
Definition at line 70 of file PALUnidirectionalSequenceLSTMCommon.h.
References luci_interpreter::RuntimeShape::flatSize().
Referenced by updateLstmCell(), and updateLstmHidden().
template<typename InputType, typename OutputType>
void luci_interpreter_pal::lstm_internal::mulElementwise(
    int size,
    const ArithmeticParams *params,
    const InputType *input1_data,
    const InputType *input2_data,
    OutputType *output_data)
Definition at line 51 of file PALUnidirectionalSequenceLSTMCommon.h.
References luci_interpreter_pal::ArithmeticParams::input1_offset, luci_interpreter_pal::ArithmeticParams::input2_offset, luci_interpreter_pal::multiplyByQuantizedMultiplier(), luci_interpreter_pal::ArithmeticParams::output_multiplier, luci_interpreter_pal::ArithmeticParams::output_offset, luci_interpreter_pal::ArithmeticParams::output_shift, luci_interpreter_pal::ArithmeticParams::quantized_activation_max, luci_interpreter_pal::ArithmeticParams::quantized_activation_min, and size.
Referenced by mul().
void luci_interpreter_pal::lstm_internal::sigmoid(
    const luci_interpreter::RuntimeShape &data_shape,
    float *data)
Definition at line 162 of file PALUnidirectionalSequenceLSTMCommon.h.
References luci_interpreter::RuntimeShape::flatSize(), and luci_interpreter_pal::Logistic().
void luci_interpreter_pal::lstm_internal::sigmoid(
    const luci_interpreter::RuntimeShape &data_shape,
    int16_t *data)
Definition at line 117 of file PALUnidirectionalSequenceLSTMCommon.h.
References luci_interpreter::RuntimeShape::flatSize(), and luci_interpreter_pal::Logistic().
Referenced by calculateLstmGate().
void luci_interpreter_pal::lstm_internal::tanh(
    int32_t cell_state_scale_power,
    const luci_interpreter::RuntimeShape &input_data_shape,
    int16_t *input_data,
    const luci_interpreter::RuntimeShape &output_data_shape,
    int16_t *output_data)
Definition at line 101 of file PALUnidirectionalSequenceLSTMCommon.h.
References luci_interpreter::RuntimeShape::flatSize(), and luci_interpreter_pal::Tanh().
Referenced by calculateLstmGate(), and updateLstmHidden().
void luci_interpreter_pal::lstm_internal::tanh(
    int32_t,
    const luci_interpreter::RuntimeShape &input_data_shape,
    float *input_data,
    const luci_interpreter::RuntimeShape &output_data_shape,
    float *output_data)
Definition at line 155 of file PALUnidirectionalSequenceLSTMCommon.h.
References luci_interpreter::RuntimeShape::flatSize(), and luci_interpreter_pal::Tanh().
template<typename CellType>
void luci_interpreter_pal::lstm_internal::updateLstmCell(
    const LstmStepManager *step_info,
    CellType *cell_state_data,
    CellType *forget_gate_output,
    const CellType *input_gate_output,
    const CellType *cell_gate_output,
    const ArithmeticParams &forget_cell_mul_params,
    const ArithmeticParams &input_mul_params,
    const luci_interpreter::lstm::CellStateInfo *cell_state_info,
    CellType *buffer)
Definition at line 391 of file PALUnidirectionalSequenceLSTMCommon.h.
References addElementWise(), luci_interpreter::lstm::CellStateInfo::cell_clip, luci_interpreter_pal::lstm_internal::LstmStepManager::cellStateOffset(), clipping(), mul(), and luci_interpreter_pal::lstm_internal::LstmStepManager::stateShape().
template<typename CellType, typename ActivationType>
void luci_interpreter_pal::lstm_internal::updateLstmHidden(
    const LstmStepManager *step_info,
    CellType *cell_state_data_base,
    ActivationType *hidden_state_data,
    const CellType *output_gate_output,
    const ArithmeticParams *mul_params,
    int32_t cell_state_scale_power,
    CellType *buffer)
Definition at line 372 of file PALUnidirectionalSequenceLSTMCommon.h.
References luci_interpreter_pal::lstm_internal::LstmStepManager::cellStateOffset(), luci_interpreter_pal::lstm_internal::LstmStepManager::hiddenStateOffset(), mul(), luci_interpreter_pal::lstm_internal::LstmStepManager::stateShape(), and tanh().