18#include "kernels/OneHot.h"
19#include "kernels/Utils.h"
// One-hot encodes an int32 indices tensor into output_tensor.
//
// The output is written as one flat sequence in [prefix][depth][suffix]
// order, where `prefix` is the product of the input dims before `axis` and
// `suffix` is the number of remaining input elements per prefix slice.
// Element (i, j, k) receives on_value when indices[i * suffix + k] == j and
// off_value otherwise. Because j only ranges over [0, depth), an index value
// outside that range never matches, so its whole one-hot vector is off_value.
//
// Parameters:
//   indices_tensor   - index data, read as int32_t.
//   on_value_tensor  - only element [0] is read; used for the "hot" position.
//   off_value_tensor - only element [0] is read; used everywhere else.
//   depth            - extent of the new one-hot dimension.
//   axis             - position of the new dimension; -1 means "after all
//                      existing input dims".
30void OneHotComputeImpl(
const Tensor *indices_tensor,
const Tensor *on_value_tensor,
31 const Tensor *off_value_tensor, int32_t depth, int32_t axis,
35 auto const &input_shape = indices_tensor->shape();
// Resolve the -1 sentinel: place the depth dimension last.
36 axis = axis == -1 ? input_shape.num_dims() : axis;
39 auto const *indices = getTensorData<int32_t>(indices_tensor);
// on/off values are treated as scalars: only the first element is used.
40 auto const on_value = getTensorData<T>(on_value_tensor)[0];
41 auto const off_value = getTensorData<T>(off_value_tensor)[0];
42 auto *
output = getTensorData<T>(output_tensor);
// prefix_dim_size = product of the input dims in [0, axis), i.e. the number
// of index slices that sit "outside" the new one-hot dimension.
47 auto prefix_dim_size = 1;
48 for (int32_t i = 0; i < axis; ++i)
50 prefix_dim_size *= input_shape.dim(i);
// NOTE(review): a zero-sized input dim before `axis` makes prefix_dim_size 0.
// That trips this assert in debug builds. Under NDEBUG it leads to an
// integer 0/0 on the next line. Confirm empty inputs are rejected earlier.
52 assert(prefix_dim_size > 0);
// Number of index elements per prefix slice (the dims from `axis` onward).
53 auto const suffix_dim_size = input_shape.num_elements() / prefix_dim_size;
// Fill the output in [prefix][depth][suffix] order; `output` is advanced
// exactly once per written element.
61 for (int32_t i = 0; i < prefix_dim_size; ++i)
62 for (int32_t j = 0; j < depth; ++j)
63 for (int32_t k = 0; k < suffix_dim_size; ++k, ++
output)
64 *output = indices[i * suffix_dim_size + k] == j ? on_value : off_value;
70 const Tensor *off_value,
Tensor *output,
const OneHotParams ¶ms)
// Registers the four inputs with the KernelWithParams base in the fixed order
// {indices, depth, on_value, off_value}, along with the single output tensor
// and the OneHot parameters.
71 : KernelWithParams<OneHotParams>({indices, depth, on_value, off_value}, {
output}, params)
// Derives the output shape from the indices shape, the depth scalar and the
// `axis` parameter. An axis of -1 is resolved to the input rank, i.e. the new
// dimension is appended last.
// NOTE(review): the two loops below iterate over output dims before `axis` and
// after `axis + 1`. Presumably they copy the corresponding input dims, with
// depth_value placed at `axis`. Confirm against the full body.
76void OneHot::configure()
// depth is read from element [0] of the depth tensor here. Its data is
// presumably required to be populated (constant) before configure runs.
91 auto const depth_value = getTensorData<int32_t>(depth())[0];
92 auto const &input_shape = indices()->shape();
93 auto const input_dims = input_shape.num_dims();
94 auto const axis = params().axis == -1 ? input_dims : params().axis;
99 for (int32_t d = 0; d < axis; ++d)
104 for (int32_t d = axis + 1; d <
output_shape.num_dims(); ++d)
// Runs the kernel by dispatching OneHotComputeImpl on the output element type.
// The visible instantiations are float, uint8_t and int16_t. Other types fall
// through to the assert at the end of the switch.
// The raw params().axis (possibly -1) is forwarded unchanged;
// OneHotComputeImpl resolves the -1 sentinel itself.
112void OneHot::execute()
const
// depth is re-read from element [0] of the depth tensor at execution time.
114 auto const depth_value = getTensorData<int32_t>(depth())[0];
115 auto const axis = params().axis;
117 switch (
output()->element_type())
119 case DataType::FLOAT32:
120 OneHotComputeImpl<float>(indices(), on_value(), off_value(), depth_value, axis,
output());
123 OneHotComputeImpl<uint8_t>(indices(), on_value(), off_value(), depth_value, axis,
output());
126 OneHotComputeImpl<int16_t>(indices(), on_value(), off_value(), depth_value, axis,
output());
// NOTE(review): with NDEBUG this assert is compiled out. An unsupported output
// type would then return without writing the output; consider a hard error.
130 assert(
false &&
"Not supported, yet!");
OneHot(const Tensor *indices, const Tensor *depth, const Tensor *on_value, const Tensor *off_value, Tensor *output, const OneHotParams ¶ms)
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
uint32_t num_elements(const Shape &shape)
The number of elements of a feature map of a given shape.