18#include "kernels/OneHot.h"
19#include "kernels/Utils.h"
48 for (int32_t
i = 0;
i < axis; ++
i)
62 for (int32_t
j = 0;
j < depth; ++
j)
70 const Tensor *off_value,
Tensor *output,
const OneHotParams &params)
71 : KernelWithParams<OneHotParams>({indices, depth, on_value, off_value}, {
output}, params)
76void OneHot::configure()
91 auto const depth_value = getTensorData<int32_t>(depth())[0];
92 auto const &input_shape = indices()->shape();
93 auto const input_dims = input_shape.num_dims();
94 auto const axis = params().axis == -1 ? input_dims : params().axis;
99 for (int32_t d = 0; d < axis; ++d)
104 for (int32_t d = axis + 1; d <
output_shape.num_dims(); ++d)
112void OneHot::execute()
const
114 auto const depth_value = getTensorData<int32_t>(depth())[0];
115 auto const axis = params().axis;
117 switch (
output()->element_type())
119 case DataType::FLOAT32:
120 OneHotComputeImpl<float>(indices(), on_value(), off_value(), depth_value, axis,
output());
123 OneHotComputeImpl<uint8_t>(indices(), on_value(), off_value(), depth_value, axis,
output());
126 OneHotComputeImpl<int16_t>(indices(), on_value(), off_value(), depth_value, axis,
output());
130 assert(
false &&
"Not supported, yet!");
OneHot(const Tensor *indices, const Tensor *depth, const Tensor *on_value, const Tensor *off_value, Tensor *output, const OneHotParams &params)
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
T must_cast(loco::Node *node)
uint32_t num_elements(const Shape &shape)
The number of elements of a feature map of a given shape.