25#include "PALLogSoftmax.h"
33constexpr uint32_t inputTensorIdx = 0;
41 const circle::Tensor *
input =
nullptr;
42 const circle::Tensor *
output =
nullptr;
47 SISOHeader(execute_args, &input, &output, &input_data, &output_data);
51 switch (
input->type())
54 case circle::TensorType_FLOAT32:
60 const auto dim_count = inputs_shape.dimensionsCount();
62 const auto trailing_dim = dim_count - 1;
65 for (
int i = 0; i < inputs_shape.dimensionsCount(); ++i)
67 flat_size *= (i == trailing_dim) ? 1 : inputs_shape.dims(i);
73 assert(inputs_shape.dims(trailing_dim) == outputs_shape.dims(trailing_dim));
74 params.
row_size = inputs_shape.dims(trailing_dim);
76 status =
pal::LogSoftmax(params, core::utils::castInputData<float>(input_data),
77 core::utils::castOutputData<float>(output_data));
84 assert(
false &&
"Unsupported type.");
constexpr uint32_t outputTensorIdx
OMStatus LogSoftmax(const core::LogSoftmaxParams ¶ms, const float *input_data, float *output_data)
OMStatus SISOHeader(const OMExecuteArgs &execute_args, const circle::Tensor **input, const circle::Tensor **output, uint8_t **input_data, uint8_t **output_data)