ONE - On-device Neural Engine
OneHot.cpp
/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernels/OneHot.h"
#include "kernels/Utils.h"

namespace luci_interpreter
{
namespace kernels
{

namespace
{

template <typename T>
void OneHotComputeImpl(const Tensor *indices_tensor, const Tensor *on_value_tensor,
                       const Tensor *off_value_tensor, int32_t depth, int32_t axis,
                       Tensor *output_tensor)
{
  // define input shape and correct axis
  auto const &input_shape = indices_tensor->shape();
  axis = axis == -1 ? input_shape.num_dims() : axis;

  // TODO support other integer input types
  auto const *indices = getTensorData<int32_t>(indices_tensor);
  auto const on_value = getTensorData<T>(on_value_tensor)[0];
  auto const off_value = getTensorData<T>(off_value_tensor)[0];
  auto *output = getTensorData<T>(output_tensor);

  // prefix_dim_size == # of elements before the axis
  // depth == # of elements per axis
  // suffix_dim_size == # of elements after the axis
  auto prefix_dim_size = 1;
  for (int32_t i = 0; i < axis; ++i)
  {
    prefix_dim_size *= input_shape.dim(i);
  }
  assert(prefix_dim_size > 0);
  auto const suffix_dim_size = input_shape.num_elements() / prefix_dim_size;

  // View the indices as a matrix of size:
  //   prefix_dim_size x suffix_dim_size
  // View the output as a matrix of size:
  //   prefix_dim_size x depth x suffix_dim_size
  // Then the output is:
  //   output(i, j, k) == (indices(i, k) == j) ? on : off
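  //
  // For example, with indices = [0, 2, 1] (shape [3]), depth = 3 and axis = -1
  // (normalized to 1), prefix_dim_size = 3 and suffix_dim_size = 1, so the
  // output has shape [3, 3]:
  //   [on,  off, off]
  //   [off, off, on ]
  //   [off, on,  off]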
  for (int32_t i = 0; i < prefix_dim_size; ++i)
    for (int32_t j = 0; j < depth; ++j)
      for (int32_t k = 0; k < suffix_dim_size; ++k, ++output)
        *output = indices[i * suffix_dim_size + k] == j ? on_value : off_value;
}

} // namespace

OneHot::OneHot(const Tensor *indices, const Tensor *depth, const Tensor *on_value,
               const Tensor *off_value, Tensor *output, const OneHotParams &params)
  : KernelWithParams<OneHotParams>({indices, depth, on_value, off_value}, {output}, params)
{
  // Do nothing
}

void OneHot::configure()
{
  // check types
  LUCI_INTERPRETER_CHECK(indices()->element_type() == DataType::S32);
  LUCI_INTERPRETER_CHECK(depth()->element_type() == DataType::S32);
  LUCI_INTERPRETER_CHECK(on_value()->element_type() == off_value()->element_type());
  LUCI_INTERPRETER_CHECK(output()->element_type() == on_value()->element_type());

  // check shape dependent parameters
  LUCI_INTERPRETER_CHECK(on_value()->shape().num_elements() == 1);
  LUCI_INTERPRETER_CHECK(off_value()->shape().num_elements() == 1);
  LUCI_INTERPRETER_CHECK(depth()->shape().num_elements() == 1);
  LUCI_INTERPRETER_CHECK(params().axis >= -1 && params().axis <= indices()->shape().num_dims());

  // define parameters that affect the output shape
  auto const depth_value = getTensorData<int32_t>(depth())[0];
  auto const &input_shape = indices()->shape();
  auto const input_dims = input_shape.num_dims();
  auto const axis = params().axis == -1 ? input_dims : params().axis;

  // define output shape
  Shape output_shape(input_shape.num_dims() + 1);
  {
    for (int32_t d = 0; d < axis; ++d)
      output_shape.dim(d) = input_shape.dim(d);

    output_shape.dim(axis) = depth_value;

    for (int32_t d = axis + 1; d < output_shape.num_dims(); ++d)
      output_shape.dim(d) = input_shape.dim(d - 1);
  }
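  // e.g. indices of shape [2, 4] with depth 3 and axis 1 give output shape [2, 3, 4]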

  // reshape output
  output()->resize(output_shape);
}

void OneHot::execute() const
{
  auto const depth_value = getTensorData<int32_t>(depth())[0];
  auto const axis = params().axis;

  switch (output()->element_type())
  {
    case loco::DataType::FLOAT32:
      OneHotComputeImpl<float>(indices(), on_value(), off_value(), depth_value, axis, output());
      break;
    case loco::DataType::U8:
      OneHotComputeImpl<uint8_t>(indices(), on_value(), off_value(), depth_value, axis, output());
      break;
    case loco::DataType::S16:
      OneHotComputeImpl<int16_t>(indices(), on_value(), off_value(), depth_value, axis, output());
      break;
    default:
      // TODO Support other data types
      throw std::runtime_error("Not supported, yet!");
      break;
  }
}

} // namespace kernels
} // namespace luci_interpreter
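
For reference, the prefix / depth / suffix decomposition used by the kernel can be exercised outside the interpreter. The standalone sketch below is an illustration, not part of the kernel: it reimplements the inner loop of OneHotComputeImpl over plain std::vector buffers, omitting the Tensor plumbing, getTensorData and LUCI_INTERPRETER_CHECK; the name one_hot_reference is made up for this example.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Reference one-hot over a flat index buffer: the input is viewed as
// prefix_dim_size x suffix_dim_size and the output as
// prefix_dim_size x depth x suffix_dim_size.
std::vector<float> one_hot_reference(const std::vector<int32_t> &indices,
                                     const std::vector<int32_t> &input_shape, int32_t depth,
                                     int32_t axis, float on_value, float off_value)
{
  const int32_t num_dims = static_cast<int32_t>(input_shape.size());
  axis = (axis == -1) ? num_dims : axis;
  assert(axis >= 0 && axis <= num_dims);

  int32_t prefix_dim_size = 1;
  for (int32_t i = 0; i < axis; ++i)
    prefix_dim_size *= input_shape[i];
  const int32_t suffix_dim_size = static_cast<int32_t>(indices.size()) / prefix_dim_size;

  std::vector<float> output(indices.size() * depth);
  auto *out = output.data();
  for (int32_t i = 0; i < prefix_dim_size; ++i)
    for (int32_t j = 0; j < depth; ++j)
      for (int32_t k = 0; k < suffix_dim_size; ++k, ++out)
        *out = indices[i * suffix_dim_size + k] == j ? on_value : off_value;
  return output;
}

int main()
{
  // indices of shape [3], depth 3, axis -1 -> output of shape [3, 3]
  const auto out = one_hot_reference({0, 2, 1}, {3}, /*depth=*/3, /*axis=*/-1, 1.0f, 0.0f);
  for (std::size_t i = 0; i < out.size(); ++i)
    std::cout << out[i] << ((i % 3 == 2) ? '\n' : ' ');
  return 0;
}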