ONE - On-device Neural Engine
Loading...
Searching...
No Matches
topk_v2.h
Go to the documentation of this file.
/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
24#ifndef __NNFW_RT_OPTIMIZED_OPS_TOPK_V2_H__
25#define __NNFW_RT_OPTIMIZED_OPS_TOPK_V2_H__
26
27typedef int32_t int32;
28
29namespace nnfw
30{
31namespace rt
32{
34{
46template <typename T> class TopContainer
47{
48public:
52 TopContainer() = delete;
58 TopContainer(int32 k, int32 row_size) : k_(k), container_(), values_(nullptr)
59 {
60 container_.reserve(std::min(k, row_size) + 1);
61 }
62
67 TopContainer(const TopContainer &) = delete;
68 /*
69 * @brief Prevent instances of this class from being copied (As this class contains pointers)
70 * @param [in] topContainer To copy
71 * @return Reference of TopContainer
72 */
74
80 void start_collecting(const T *values)
81 {
82 values_ = values;
83 container_.clear();
84 }
85
91 void push(int32 a)
92 {
93 auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); };
94 if (container_.size() <= (size_t)k_)
95 {
96 container_.push_back(a);
97 if (container_.size() == (size_t)(k_ + 1))
98 {
99 std::make_heap(container_.begin(), container_.end(), comparator);
100 std::pop_heap(container_.begin(), container_.end(), comparator);
101 }
102 }
103 else if (comparator(a, container_.front()))
104 {
105 container_.back() = a;
106 std::push_heap(container_.begin(), container_.end(), comparator);
107 std::pop_heap(container_.begin(), container_.end(), comparator);
108 }
109 }
110
115 const std::vector<int32> &sorted_result()
116 {
117 auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); };
118 if (container_.size() <= (size_t)(k_))
119 {
120 std::sort(container_.begin(), container_.end(), comparator);
121 }
122 else
123 {
124 std::sort_heap(container_.begin(), container_.end() - 1, comparator);
125 container_.resize(k_);
126 }
127 return container_;
128 }
129
130private:
131 int32 k_;
132 std::vector<int32> container_;
133 const T *values_ = nullptr;
134
135 bool compare_fun(int32 a, int32 b) const
136 {
137 if (values_[b] < values_[a])
138 {
139 return true;
140 }
141 else if (values_[b] > values_[a])
142 {
143 return false;
144 }
145 else
146 {
147 return a < b;
148 }
149 }
150};
151
162template <typename T>
163void TopK(int32 row_size, int32 num_rows, const T *data, int32 k, int32 *output_indexes,
164 T *output_values)
165{
166 TopContainer<T> topc(k, row_size);
167 for (int row = 0; row < num_rows; ++row)
168 {
169 const T *values_row = data + row * row_size;
170 topc.start_collecting(values_row);
171 for (int32 c = 0; c < row_size; ++c)
172 {
173 topc.push(c);
174 }
175
176 // Prepare output buffers.
177 int32 *indexes_row = output_indexes + row * k;
178 T *output_row = output_values + row * k;
179 // We always assume that the output is sorted.
180 const auto &top_k = topc.sorted_result();
181 std::copy(top_k.begin(), top_k.end(), indexes_row);
182 std::transform(top_k.begin(), top_k.end(), output_row,
183 [values_row](const int32 loc) { return values_row[loc]; });
184 }
185}
186
187} // namespace optimized_ops
188} // namespace rt
189} // namespace nnfw
190
191#endif // __NNFW_RT_OPTIMIZED_OPS_TOPK_V2_H__
class to define TopK operation
Definition topk_v2.h:47
TopContainer(const TopContainer &)=delete
Prevent instances of this class from being copied (As this class contains pointers)
void push(int32 a)
Push a value to be compared for topk.
Definition topk_v2.h:91
TopContainer(int32 k, int32 row_size)
Constructor with params.
Definition topk_v2.h:58
TopContainer & operator=(const TopContainer &)=delete
TopContainer()=delete
Prevent default construction of this class.
const std::vector< int32 > & sorted_result()
Get sorted result from pushed values.
Definition topk_v2.h:115
void start_collecting(const T *values)
Start collecting.
Definition topk_v2.h:80
void TopK(int32 row_size, int32 num_rows, const T *data, int32 k, int32 *output_indexes, T *output_values)
Performs the TopK operation with the given parameters.
Definition topk_v2.h:163
Definition topk_v2.h:30
int32_t int32
Definition topk_v2.h:27