ONE - On-device Neural Engine
Loading...
Searching...
No Matches
LogSoftmax.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "kernels/LogSoftmax.h"
18
19#include "kernels/Utils.h"
20
21#include <tensorflow/lite/kernels/internal/reference/log_softmax.h>
22
23#include "PALLogSoftmax.h"
24
25namespace luci_interpreter
26{
27namespace kernels
28{
29
30LogSoftmax::LogSoftmax(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
31
// NOTE(review): the signature line of this definition is missing from this listing
// (presumably `void LogSoftmax::configure()` -- confirm against kernels/LogSoftmax.h).
//
// Validates the tensors and prepares the output:
//  - input and output must have the same element type;
//  - for U8, the output must be quantized with scale 16/256 and zero point 255,
//    and a lookup table is prepared from the input scale with beta = 1.0;
//  - the output tensor is resized to the input's shape.
// A failed check is reported through LUCI_INTERPRETER_CHECK (presumably by throwing --
// see Utils.h). Element types other than U8 are not rejected here; execute() is where
// unsupported types are refused.
33{
34 LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
35 if (input()->element_type() == DataType::U8)
36 {
// With scale 1/16 and zero point 255, uint8 codes 0..255 represent real values in
// [-255/16, 0]. Log-softmax outputs are never positive, so this fixed output
// quantization covers the non-positive range.
37 LUCI_INTERPRETER_CHECK(output()->scale() == 16. / 256);
38 LUCI_INTERPRETER_CHECK(output()->zero_point() == 255);
39
40 tflite::SoftmaxParams params{};
41
// params.table aliases the member _table, so the lookup table is presumably written
// into _table by the call below; evalQuantized() later hands _table back to the PAL
// through params.table.
42 params.table = _table;
43 params.beta = 1.0;
44 luci_interpreter_pal::PopulateSoftmaxLookupTable(&params, input()->scale(), params.beta);
45 }
// Log-softmax is element-wise over its reduction axis, so the output shape equals the input shape.
46 output()->resize(input()->shape());
47}
48
// NOTE(review): the signature line of this definition is missing from this listing
// (presumably `void LogSoftmax::execute() const` -- confirm against kernels/LogSoftmax.h).
//
// Runs the kernel by dispatching on the input element type:
//   FLOAT32 -> evalFloat(), U8 -> evalQuantized().
// @throws std::runtime_error for any other element type.
50{
51 switch (input()->element_type())
52 {
53 case DataType::FLOAT32:
54 evalFloat();
55 break;
56 case DataType::U8:
57 evalQuantized();
58 break;
59 default:
60 throw std::runtime_error("luci-intp LogSoftmax Unsupported type.");
61 }
62}
63
64void LogSoftmax::evalFloat() const
65{
66 tflite::SoftmaxParams params{};
67 tflite::reference_ops::LogSoftmax(params, getTensorShape(input()), getTensorData<float>(input()),
68 getTensorShape(output()), getTensorData<float>(output()));
69}
70
71void LogSoftmax::evalQuantized() const
72{
73 const auto input_shape = getTensorShape(input());
74 const auto output_shape = getTensorShape(output());
75 const auto input_scale = input()->scale();
76 uint8_t *output_data = getTensorData<uint8_t>(output());
77 const uint8_t *input_data = getTensorData<uint8_t>(input());
78 const float beta = 1.0;
79
80 tflite::SoftmaxParams params{};
81
82 params.table = const_cast<float *>(_table);
83 params.zero_point = output()->zero_point();
84 params.scale = output()->scale();
85
86 luci_interpreter_pal::InitializeParams(&params, input_scale, beta);
87 luci_interpreter_pal::LogSoftmax(params, input_scale, input_shape, input_data, output_shape,
88 output_data);
89}
90
91} // namespace kernels
92} // namespace luci_interpreter
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
float scale() const
Definition Tensor.h:109
int32_t zero_point() const
Definition Tensor.h:115
LogSoftmax(const Tensor *input, Tensor *output)
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194