24#include "PALSoftmax.h"
constexpr uint32_t inputTensorIdx = 0;
constexpr uint32_t outputTensorIdx = 0;

static const int kScaledDiffIntegerBits = 5;
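
// Folds beta and the input scale into a single fixed-point multiplier, mirroring the
// TFLite reference softmax: the real multiplier beta * input_scale * 2^(31 - input_integer_bits)
// is converted by quantizeMultiplier() into a Q31 multiplier plus a left shift.
// kScaledDiffIntegerBits reserves 5 integer bits for the scaled input differences.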
void preprocessSoftmaxScaling(double beta, double input_scale, int input_integer_bits,
                              int32_t *quantized_multiplier, int *left_shift)
{
  const double max_real_multiplier = (1LL << 31) - 1.0;
  const double input_beta_real_multiplier =
    std::min<double>(beta * input_scale * (1 << (31 - input_integer_bits)), max_real_multiplier);

  quantizeMultiplier(input_beta_real_multiplier, quantized_multiplier, left_shift);
}
// Softmax kernel execution entry point; the OMExecuteArgs-based signature follows the
// usual onert-micro execute-kernel convention.
OMStatus onert_micro::execute::execute_kernel_CircleSoftmax(const OMExecuteArgs &execute_args)
{
  core::OMRuntimeContext &runtime_context = execute_args.runtime_context;
  core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage;
  uint16_t op_index = execute_args.kernel_index;

  const circle::Tensor *input = nullptr;
  const circle::Tensor *output = nullptr;

  uint8_t *input_data = nullptr;
  uint8_t *output_data = nullptr;

  const circle::SoftmaxOptions *options = nullptr;
  OMStatus status = Ok;

  execute::OMRuntimeKernel runtime_kernel;
  runtime_kernel.readKernel(op_index, runtime_context);

  input = runtime_kernel.inputs[inputTensorIdx];
  output = runtime_kernel.outputs[outputTensorIdx];
  assert(input != nullptr);
  assert(output != nullptr);

  status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context);
  if (status != Ok)
    return status;
  input_data = runtime_kernel.inputs_data[inputTensorIdx];
  output_data = runtime_kernel.outputs_data[outputTensorIdx];
  assert(input_data != nullptr);
  assert(output_data != nullptr);

  options = runtime_kernel.first_operator->builtin_options_as_SoftmaxOptions();
  assert(options != nullptr);

  const float beta = options->beta();

  core::OMRuntimeShape inputs_shape(input);
  core::OMRuntimeShape outputs_shape(output);
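
  // The softmax is applied along the trailing (last) dimension: the tensor is viewed as
  // num_rows rows of row_size elements, where num_rows is the product of all
  // non-trailing dimensions.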
  const auto dim_count = inputs_shape.dimensionsCount();
  const auto trailing_dim = dim_count - 1;

  int flat_size = 1;
  for (int i = 0; i < inputs_shape.dimensionsCount(); ++i)
  {
    flat_size *= (i == trailing_dim) ? 1 : inputs_shape.dims(i);
  }

  core::SoftmaxParams params{};
  params.beta = beta;
  params.num_rows = flat_size;
  params.row_size = std::min(inputs_shape.dims(trailing_dim), outputs_shape.dims(trailing_dim));
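
  // Dispatch on the input element type. The actual computation lives in the platform
  // abstraction layer (PALSoftmax.h), which provides a templated pal::Softmax over the
  // input/output element types.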
  switch (input->type())
  {
    case circle::TensorType_FLOAT32:
    {
      status = pal::Softmax(params, core::utils::castInputData<float>(input_data),
                            core::utils::castOutputData<float>(output_data));
    }
    break;
    case circle::TensorType_INT8:
    {
      assert(output->type() == circle::TensorType_INT8);
      if (output->type() != circle::TensorType_INT8)
        return UnsupportedType;
      assert(input->quantization() != nullptr and output->quantization() != nullptr);
      assert(input->quantization()->scale() != nullptr and
             output->quantization()->scale() != nullptr);
      assert(input->quantization()->zero_point() != nullptr and
             output->quantization()->zero_point() != nullptr);
      assert(input->quantization()->scale()->size() == 1 and
             output->quantization()->scale()->size() == 1);
      assert(input->quantization()->zero_point()->size() == 1 and
             output->quantization()->zero_point()->size() == 1);
      params.output_scale = output->quantization()->scale()->operator[](0);
      params.input_scale = input->quantization()->scale()->operator[](0);
      params.output_zp = output->quantization()->zero_point()->operator[](0);
      params.input_zp = input->quantization()->zero_point()->operator[](0);
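
      // The quantized kernel works on scaled differences (input - row max), so only
      // beta * input_scale needs to be folded into a fixed-point multiplier here;
      // quantizeMultiplier() returns the Q31 multiplier and the accompanying left shift.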
      int left_shift = 0;
      preprocessSoftmaxScaling(static_cast<double>(params.beta),
                               static_cast<double>(params.input_scale), kScaledDiffIntegerBits,
                               &params.input_multiplier, &left_shift);
      params.input_left_shift = left_shift;
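      // diff_min is the most negative scaled (input - max) difference that still
      // contributes to the softmax sum; smaller differences are treated as exp(-inf) = 0.
      // This assumes core::SoftmaxParams exposes a diff_min field, as in the TFLite
      // reference SoftmaxParams.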
      params.diff_min =
        -1.0 * calculateInputRadius(kScaledDiffIntegerBits, params.input_left_shift, 31);
      status = pal::Softmax(params, core::utils::castInputData<int8_t>(input_data),
                            core::utils::castOutputData<int8_t>(output_data));
    }
    break;
    default:
    {
      status = UnsupportedType;
      assert(false && "Unsupported type.");
    }
    break;
  }

  return status;
}