63 const circle::Tensor *input =
nullptr;
64 const circle::Tensor *output =
nullptr;
66 uint8_t *input_data =
nullptr;
67 uint8_t *output_data =
nullptr;
71 const circle::SoftmaxOptions *options;
74 runtime_kernel.
readKernel(op_index, runtime_context);
76 input = runtime_kernel.
inputs[inputTensorIdx];
77 output = runtime_kernel.
outputs[outputTensorIdx];
79 assert(input !=
nullptr);
80 assert(output !=
nullptr);
82 status = runtime_kernel.
getDataFromStorage(op_index, runtime_storage, runtime_context);
86 input_data = runtime_kernel.
inputs_data[inputTensorIdx];
87 output_data = runtime_kernel.
outputs_data[outputTensorIdx];
89 options = runtime_kernel.
first_operator->builtin_options_as_SoftmaxOptions();
92 assert(input_data !=
nullptr);
93 assert(output_data !=
nullptr);
95 const float beta = options->beta();
102 const auto trailing_dim = dim_count - 1;
107 flat_size *= (i == trailing_dim) ? 1 : inputs_shape.
dims(i);
112 params.num_rows = flat_size;
113 params.row_size = std::min(inputs_shape.
dims(trailing_dim), outputs_shape.
dims(trailing_dim));
115 switch (input->type())
118 case circle::TensorType_FLOAT32:
121 status =
pal::Softmax(params, core::utils::castInputData<float>(input_data),
122 core::utils::castOutputData<float>(output_data));
127 case circle::TensorType_INT8:
129 assert(output->type() == circle::TensorType_INT8);
130 if (output->type() != circle::TensorType_INT8)
133 assert(input->quantization() !=
nullptr and output->quantization() !=
nullptr);
134 assert(input->quantization()->scale() !=
nullptr and
135 output->quantization()->scale() !=
nullptr);
136 assert(input->quantization()->zero_point() !=
nullptr and
137 output->quantization()->zero_point() !=
nullptr);
138 assert(input->quantization()->scale()->size() == 1 and
139 output->quantization()->scale()->size() == 1);
140 assert(input->quantization()->zero_point()->size() == 1 and
141 output->quantization()->zero_point()->size() == 1);
143 params.output_scale = output->quantization()->scale()->operator[](0);
144 params.input_scale = input->quantization()->scale()->operator[](0);
145 params.output_zp = output->quantization()->zero_point()->operator[](0);
146 params.input_zp = input->quantization()->zero_point()->operator[](0);
149 preprocessSoftmaxScaling(
static_cast<double>(params.beta),
150 static_cast<double>(params.input_scale), kScaledDiffIntegerBits,
151 &params.input_multiplier, &left_shift);
152 params.input_left_shift = left_shift;
154 kScaledDiffIntegerBits, params.input_left_shift, 31);
156 status =
pal::Softmax(params, core::utils::castInputData<int8_t>(input_data),
157 core::utils::castOutputData<int8_t>(output_data));
164 assert(
false &&
"Unsupported type.");