ONE - On-device Neural Engine
Loading...
Searching...
No Matches
TopKV2.h
Go to the documentation of this file.
/*
 * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17
18#ifndef __NNFW_CKER_TOPK_V2_H__
19#define __NNFW_CKER_TOPK_V2_H__
20
#include "cker/Shape.h"

#include <algorithm>
#include <cstdint>
#include <vector>
22
23namespace nnfw::cker
24{
25
/**
 * @brief Incrementally selects the indices of the k largest values of a row.
 *
 * Usage: call start_collecting() with the row's values, push() every candidate
 * index, then read the result with sorted_result(). Values are never copied;
 * only indices into the array given to start_collecting() are stored.
 *
 * @tparam T    Type of the values being ranked.
 * @tparam Tidx Type of the indices that are stored and returned.
 */
template <typename T, typename Tidx> class TopContainer
{
public:
  TopContainer() = delete;

  /**
   * @param k        Number of largest elements to keep.
   * @param row_size Number of candidates per row; only used to size the
   *                 initial reservation of the index buffer.
   */
  TopContainer(uint32_t k, uint32_t row_size) : k_(k)
  {
    // One extra slot because push() briefly holds k + 1 indices before the
    // first heapify.
    container_.reserve(std::min(k, row_size) + 1);
  }

  /**
   * @brief Starts a new selection over @p values, discarding previous indices.
   *
   * @param values Array indexed by the indices later passed to push(). It is
   *               not copied and must stay valid while pushing and reading
   *               sorted_result().
   */
  void start_collecting(const T *values)
  {
    values_ = values;
    container_.clear();
    is_heap_ = false;
  }

  /**
   * @brief Offers index @p a as a candidate for the top k.
   *
   * The index is retained only if its value ranks among the k largest pushed
   * so far (equal values favor the smaller index).
   */
  void push(Tidx a)
  {
    // With k == 0 nothing can ever be kept. Returning early also avoids the
    // heap branch below, which would otherwise call front()/pop_heap()/back()
    // on an empty container (undefined behavior) from the second push on.
    if (k_ == 0)
    {
      return;
    }

    auto comparator = [this](Tidx lhs, Tidx rhs) { return compare_fun(lhs, rhs); };
    if (!is_heap_)
    {
      container_.push_back(a);
      // `size() > k_` rather than `size() == k_ + 1` so the uint32_t addition
      // cannot wrap around; both first become true at size k + 1.
      if (container_.size() > k_)
      {
        // Heapify the k + 1 candidates and drop the weakest one.
        std::make_heap(container_.begin(), container_.end(), comparator);
        std::pop_heap(container_.begin(), container_.end(), comparator);
        container_.pop_back();
        is_heap_ = true;
      }
    }
    else if (comparator(a, container_.front()))
    {
      // Due to how comparator / compare_fun is defined, container_.front()
      // contains the index of the smallest of the top-k elements seen so far.
      //
      // Reaching this point means the element at index a is bigger than that
      // smallest top-k element, so replace its index with a and make sure
      // container_[0:k] is still a heap.
      std::pop_heap(container_.begin(), container_.end(), comparator);
      container_.back() = a;
      std::push_heap(container_.begin(), container_.end(), comparator);
    }
  }

  /**
   * @brief Returns the retained indices ordered by decreasing value.
   *
   * Equal values are ordered by increasing index. The returned reference
   * points to internal storage; its contents change on the next
   * start_collecting() or push().
   */
  const std::vector<Tidx> &sorted_result()
  {
    auto comparator = [this](Tidx lhs, Tidx rhs) { return compare_fun(lhs, rhs); };
    if (is_heap_)
    {
      std::sort_heap(container_.begin(), container_.end(), comparator);
      // sort_heap destroys the heap property. Clear the flag so a later push()
      // re-heapifies instead of treating the sorted range as a heap.
      is_heap_ = false;
    }
    else
    {
      // compare_fun is inverted (see below), so std::sort puts the indices in
      // decreasing order of the corresponding elements.
      std::sort(container_.begin(), container_.end(), comparator);
    }
    return container_;
  }

private:
  // Number of largest elements to keep.
  const uint32_t k_;

  // container_[0,k) holds the indices of the largest k elements from values_
  // seen so far. If more than k elements are pushed, the indices are
  // maintained in min-heap order: container_.front() is the index of the
  // smallest of the top-k elements seen so far.
  std::vector<Tidx> container_;

  // Once more than k elements are pushed, the container becomes a min heap
  // and is_heap_ becomes true.
  bool is_heap_ = false;

  // Values referenced by the stored indices; not owned.
  const T *values_ = nullptr;

  // Compares indices a and b based on the corresponding elements from values_.
  //
  // compare_fun(a, b) returns true iff values_[b] < values_[a] (note the
  // inverted direction; this is intentional). Ties (==) are broken in favor
  // of earlier elements (i.e., a < b).
  //
  // NOTE(review): for floating-point T, NaN compares neither < nor >, so it is
  // treated as a tie. That breaks the strict weak ordering the std heap/sort
  // algorithms require.
  bool compare_fun(Tidx a, Tidx b) const
  {
    if (values_[b] < values_[a])
    {
      return true;
    }
    if (values_[b] > values_[a])
    {
      return false;
    }
    return a < b;
  }
};
125
126template <typename T, typename Tidx = int32_t>
127inline void TopKV2(const Shape &input_shape, const T *input_data, const uint32_t k,
128 T *output_value_data, Tidx *output_indices_data)
129{
130 const int32_t row_size = input_shape.Dims(input_shape.DimensionsCount() - 1);
131 int32_t num_rows = 1;
132 for (int32_t i = 0; i < input_shape.DimensionsCount() - 1; ++i)
133 {
134 num_rows *= input_shape.Dims(i);
135 }
136
137 TopContainer<T, Tidx> topc(k, row_size);
138 for (int32_t row = 0; row < num_rows; ++row)
139 {
140 const T *values_row = input_data + row * row_size;
141 topc.start_collecting(values_row);
142 for (int32_t c = 0; c < row_size; ++c)
143 {
144 topc.push(c);
145 }
146
147 // Prepare output buffers.
148 Tidx *indexes_row = output_indices_data + row * k;
149 T *output_row = output_value_data + row * k;
150 // We always assume that the output is sorted.
151 const auto &top_k = topc.sorted_result();
152 std::copy(top_k.begin(), top_k.end(), indexes_row);
153 std::transform(top_k.begin(), top_k.end(), output_row,
154 [values_row](const int32_t loc) { return values_row[loc]; });
155 }
156}
157
158} // namespace nnfw::cker
159
160#endif // __NNFW_CKER_TOPK_V2_H__
int32_t DimensionsCount() const
Definition Shape.h:103
int32_t Dims(int i) const
Definition Shape.h:106
TopContainer(uint32_t k, uint32_t row_size)
Definition TopKV2.h:30
void start_collecting(const T *values)
Definition TopKV2.h:35
const std::vector< Tidx > & sorted_result()
Definition TopKV2.h:72
void push(Tidx a)
Definition TopKV2.h:42
void TopKV2(const Shape &input_shape, const T *input_data, const uint32_t k, T *output_value_data, Tidx *output_indices_data)
Definition TopKV2.h:127