ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Logistic.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Builders.h"
18#include "kernels/Utils.h"
19#include "PALLogistic.h"
20#include "SISOKernel.h"
21
22namespace luci_interpreter
23{
24
25void configure_kernel_CircleLogistic(const circle::Operator *cur_op,
26 BaseRuntimeGraph *runtime_graph)
27{
28 kernels::SISOKernel kernel(cur_op, runtime_graph);
29
30 LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input()) ==
31 Tensor::element_type(kernel.output()))
32
33#ifndef DIS_QUANT
34 if (Tensor::element_type(kernel.input()) == DataType::U8)
35 {
36 LUCI_INTERPRETER_CHECK(Tensor::scale(kernel.output()) == 1. / 256);
37 }
38#endif // DIS_QUANT
39}
40
41void execute_kernel_CircleLogistic(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
42{
43 kernels::SISOKernel kernel(cur_op, runtime_graph);
44
45 const auto input = kernel.input();
46 const auto output = kernel.output();
47
48 bool is_inplace = runtime_graph->is_inplace_op(cur_op);
49
50 const uint8_t *input_data = runtime_graph->getDataByTensor(input);
51 uint8_t *output_data = runtime_graph->getDataByTensor(output);
52
53 if (is_inplace)
54 {
55 output_data = const_cast<uint8_t *>(input_data);
56 }
57
58 assert(input_data != nullptr);
59 assert(output_data != nullptr);
60
61 const int flat_size = kernels::getTensorRuntimeShape(input, runtime_graph).flatSize();
62
63 switch (Tensor::element_type(input))
64 {
65#ifndef DIS_FLOAT
66 case DataType::FLOAT32:
67 luci_interpreter_pal::Logistic(flat_size, kernels::getTensorData<float>(input_data),
68 kernels::getTensorData<float>(output_data));
69 break;
70#endif // DIS_FLOAT
71#ifndef DIS_QUANT
72 case DataType::S8:
73 luci_interpreter_pal::Logistic(flat_size, kernels::getTensorData<int8_t>(input_data),
74 Tensor::scale(input), Tensor::zero_point(input),
75 kernels::getTensorData<int8_t>(output_data),
76 Tensor::scale(output), Tensor::zero_point(output));
77 break;
78#endif // DIS_QUANT
79 default:
80 assert(false && "Unsupported type.");
81 }
82
83 if (is_inplace)
84 {
85 runtime_graph->makeInplaceOperation(input, output);
86 }
87}
88
89} // namespace luci_interpreter
void makeInplaceOperation(const circle::Tensor *src_tensor, const circle::Tensor *dst_tensor)
bool is_inplace_op(const circle::Operator *op)
uint8_t * getDataByTensor(const circle::Tensor *raw_tensor)
const circle::Tensor * output() const
Definition SISOKernel.h:47
const circle::Tensor * input() const
Definition SISOKernel.h:46
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
luci_interpreter::RuntimeShape getTensorRuntimeShape(const circle::Tensor *circle_tensor, BaseRuntimeGraph *runtime_graph)
Definition Utils.cpp:29
void Logistic(const int flat_size, const float *input_data, float *output_data)
Definition PALGRU.h:26
void execute_kernel_CircleLogistic(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition Logistic.cpp:41
void configure_kernel_CircleLogistic(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition Logistic.cpp:25