24#ifndef __NNFW_RT_OPTIMIZED_OPS_TOPK_V2_H__
25#define __NNFW_RT_OPTIMIZED_OPS_TOPK_V2_H__
60 container_.reserve(std::min(k, row_size) + 1);
93 auto comparator = [
this](
int32 a,
int32 b) {
return compare_fun(a, b); };
94 if (container_.size() <= (
size_t)k_)
96 container_.push_back(a);
97 if (container_.size() == (
size_t)(k_ + 1))
99 std::make_heap(container_.begin(), container_.end(), comparator);
100 std::pop_heap(container_.begin(), container_.end(), comparator);
103 else if (comparator(a, container_.front()))
105 container_.back() = a;
106 std::push_heap(container_.begin(), container_.end(), comparator);
107 std::pop_heap(container_.begin(), container_.end(), comparator);
117 auto comparator = [
this](
int32 a,
int32 b) {
return compare_fun(a, b); };
118 if (container_.size() <= (
size_t)(k_))
120 std::sort(container_.begin(), container_.end(), comparator);
124 std::sort_heap(container_.begin(), container_.end() - 1, comparator);
125 container_.resize(k_);
132 std::vector<int32> container_;
133 const T *values_ =
nullptr;
137 if (values_[b] < values_[a])
141 else if (values_[b] > values_[a])
167 for (
int row = 0; row < num_rows; ++row)
169 const T *values_row = data + row * row_size;
171 for (
int32 c = 0; c < row_size; ++c)
177 int32 *indexes_row = output_indexes + row * k;
178 T *output_row = output_values + row * k;
181 std::copy(top_k.begin(), top_k.end(), indexes_row);
182 std::transform(top_k.begin(), top_k.end(), output_row,
183 [values_row](
const int32 loc) { return values_row[loc]; });
Class that defines the TopK operation.
TopContainer(const TopContainer &)=delete
Prevent instances of this class from being copied (as this class contains pointers).
void push(int32 a)
Push a value to be compared for topk.
TopContainer(int32 k, int32 row_size)
Constructor with params.
TopContainer & operator=(const TopContainer &)=delete
TopContainer()=delete
Prevent default construction of this class.
const std::vector< int32 > & sorted_result()
Get sorted result from pushed values.
void start_collecting(const T *values)
Start collecting.
void TopK(int32 row_size, int32 num_rows, const T *data, int32 k, int32 *output_indexes, T *output_values)
Performs the TopK operation with the given parameters.