ONE - On-device Neural Engine
Loading...
Searching...
No Matches
OneHot.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_ONEHOT_H__
19#define __NNFW_CKER_ONEHOT_H__
20
#include "cker/Shape.h"

#include <cstdint>
22
23namespace nnfw
24{
25namespace cker
26{
27
28template <typename T, typename TI>
29void OneHot(const int32_t depth, const T on_value, const T off_value, int32_t axis,
30 const Shape &indices_shape, const TI *indices_data, const Shape &, T *output_data)
31{
32 if (axis == -1)
33 axis = indices_shape.DimensionsCount();
34
35 // prefix_dim_size == # of elements before the axis
36 // depth == # of elements per axis
37 // suffix_dim_size == # of elements after the axis
38 int prefix_dim_size = 1;
39 for (int i = 0; i < axis; ++i)
40 {
41 prefix_dim_size *= indices_shape.Dims(i);
42 }
43 const int suffix_dim_size = indices_shape.FlatSize() / prefix_dim_size;
44
45 // View the indices as a matrix of size:
46 // prefix_dim_size x suffix_dim_size
47 // View the output as a matrix of size:
48 // prefix_dim_size x depth x suffix_dim_size
49 // Then the output is:
50 // output(i, j, k) == (indices(i, k) == j) ? on : off
51 for (int i = 0; i < prefix_dim_size; ++i)
52 {
53 for (int j = 0; j < depth; ++j)
54 {
55 for (int k = 0; k < suffix_dim_size; ++k, ++output_data)
56 {
57 *output_data =
58 static_cast<int>(indices_data[i * suffix_dim_size + k]) == j ? on_value : off_value;
59 }
60 }
61 }
62}
63
64} // namespace cker
65} // namespace nnfw
66
67#endif // __NNFW_CKER_ONEHOT_H__
int32_t DimensionsCount() const
Definition Shape.h:91
int32_t Dims(int i) const
Definition Shape.h:92
int FlatSize() const
Definition Shape.h:181
void OneHot(const int32_t depth, const T on_value, const T off_value, int32_t axis, const Shape &indices_shape, const TI *indices_data, const Shape &, T *output_data)
Definition OneHot.h:29
Definition topk_v2.h:30