ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALFullyConnected.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
18#define LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
19
20#include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
21#include <tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h>
22#include <arm_nnfunctions.h>
23
25{
26template <typename T>
27static inline void FullyConnected(const tflite::FullyConnectedParams &params,
28 const tflite::RuntimeShape &input_shape, const T *input_data,
29 const tflite::RuntimeShape &filter_shape, const T *filter_data,
30 const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
31 const tflite::RuntimeShape &output_shape, T *output_data)
32{
33 {
34 // MARK: At this moment this operation doesn't support
35 assert(false && "FullyConnected NYI");
36 (void)params;
37 (void)input_shape;
38 (void)input_data;
39 (void)filter_shape;
40 (void)filter_data;
41 (void)bias_shape;
42 (void)bias_data;
43 (void)output_shape;
44 (void)output_data;
45 }
46}
47
48template <>
49inline void
50FullyConnected<int8_t>(const tflite::FullyConnectedParams &params,
51 const tflite::RuntimeShape &input_shape, const int8_t *input_data,
52 const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
53 const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
54 const tflite::RuntimeShape &output_shape, int8_t *output_data)
55{
56 assert(output_shape.DimensionsCount() == 2);
57
58 const int batches = output_shape.Dims(0);
59 const int output_depth = output_shape.Dims(1);
60
61 const int filter_dim_count = filter_shape.DimensionsCount();
62 const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
63
64 cmsis_nn_fc_params fc_params;
65 fc_params.input_offset = params.input_offset;
66 fc_params.output_offset = params.output_offset;
67 fc_params.filter_offset = params.weights_offset;
68 fc_params.activation.min = params.quantized_activation_min;
69 fc_params.activation.max = params.quantized_activation_max;
70
71 cmsis_nn_per_tensor_quant_params quant_params;
72 quant_params.multiplier = params.output_multiplier;
73 quant_params.shift = params.output_shift;
74
75 cmsis_nn_dims input_dims;
76 input_dims.n = batches;
77 input_dims.h = 1;
78 input_dims.w = 1;
79 input_dims.c = accum_depth;
80
81 cmsis_nn_dims filter_dims;
82 filter_dims.n = accum_depth;
83 filter_dims.h = 1;
84 filter_dims.w = 1;
85 filter_dims.c = output_depth;
86
87 cmsis_nn_dims bias_dims;
88 bias_dims.n = 1;
89 bias_dims.h = 1;
90 bias_dims.w = 1;
91 bias_dims.c = output_depth;
92
93 cmsis_nn_dims output_dims;
94 output_dims.n = batches;
95 output_dims.h = 1;
96 output_dims.w = 1;
97 output_dims.c = output_depth;
98
99 int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
100 auto buffer = std::make_unique<int8_t[]>(buf_size);
101 assert(buffer != nullptr);
102
103 cmsis_nn_context ctx;
104 ctx.buf = buffer.get();
105 ctx.size = buf_size;
106
107 auto res =
108 arm_fully_connected_s8(&ctx, &fc_params, &quant_params, &input_dims, input_data, &filter_dims,
109 filter_data, &bias_dims, bias_data, &output_dims, output_data);
110 assert(res == ARM_MATH_SUCCESS);
111}
112} // namespace luci_interpreter_pal
113
114#endif // LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
const luci_interpreter::RuntimeShape output_shape
void FullyConnected< int8_t >(const tflite::FullyConnectedParams &params, const tflite::RuntimeShape &input_shape, const int8_t *input_data, const tflite::RuntimeShape &filter_shape, const int8_t *filter_data, const tflite::RuntimeShape &bias_shape, const int32_t *bias_data, const tflite::RuntimeShape &output_shape, int8_t *output_data)
void FullyConnected(const FullyConnectedParams &params, const Shape &input_shape, const float *input_data, const Shape &weights_shape, const float *weights_data, const Shape &, const float *bias_data, const Shape &, float *output_data)