18#include "kernels/OneHot.h"
19#include "kernels/Utils.h"
template <typename T>
void OneHotComputeImpl(const Tensor *indices_tensor, const Tensor *on_value_tensor,
                       const Tensor *off_value_tensor, int32_t depth, int32_t axis,
                       Tensor *output_tensor)
{
  // Resolve axis == -1 to the innermost position of the output.
  auto const &input_shape = indices_tensor->shape();
  axis = axis == -1 ? input_shape.num_dims() : axis;

  // Scalar on/off values and raw buffers (indices are S32 only).
  auto const *indices = getTensorData<int32_t>(indices_tensor);
  auto const on_value = getTensorData<T>(on_value_tensor)[0];
  auto const off_value = getTensorData<T>(off_value_tensor)[0];
  auto *output = getTensorData<T>(output_tensor);
  // prefix_dim_size: # of index elements before `axis`; suffix_dim_size: # after it.
  auto prefix_dim_size = 1;
  for (int32_t i = 0; i < axis; ++i)
  {
    prefix_dim_size *= input_shape.dim(i);
  }
  assert(prefix_dim_size > 0);
  auto const suffix_dim_size = input_shape.num_elements() / prefix_dim_size;
  // Viewing indices as [prefix_dim_size x suffix_dim_size] and the output as
  // [prefix_dim_size x depth x suffix_dim_size]:
  //   output(i, j, k) = (indices(i, k) == j) ? on_value : off_value
  for (int32_t i = 0; i < prefix_dim_size; ++i)
    for (int32_t j = 0; j < depth; ++j)
      for (int32_t k = 0; k < suffix_dim_size; ++k, ++output)
        *output = indices[i * suffix_dim_size + k] == j ? on_value : off_value;
}
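// Illustrative sketch (not part of OneHot.cpp): the same prefix/depth/suffix
// decomposition written against plain std::vector buffers so the indexing can be
// checked in isolation. The helper name one_hot_sketch and the values in main()
// are hypothetical; only the loop structure mirrors OneHotComputeImpl above.
#include <cstdint>
#include <vector>

std::vector<float> one_hot_sketch(const std::vector<int32_t> &indices, int32_t depth,
                                  float on_value, float off_value)
{
  // Rank-1 indices with axis == -1 resolve to axis == 1, so prefix_dim_size is
  // indices.size(), suffix_dim_size is 1, and the output is laid out as
  // [indices.size(), depth] in row-major order.
  const auto prefix_dim_size = static_cast<int32_t>(indices.size());
  const int32_t suffix_dim_size = 1;

  std::vector<float> result(static_cast<size_t>(prefix_dim_size) * depth);
  auto *output = result.data();
  for (int32_t i = 0; i < prefix_dim_size; ++i)
    for (int32_t j = 0; j < depth; ++j)
      for (int32_t k = 0; k < suffix_dim_size; ++k, ++output)
        *output = indices[i * suffix_dim_size + k] == j ? on_value : off_value;
  return result;
}

int main()
{
  // indices {0, 2, 1} with depth 3 produce rows (1,0,0), (0,0,1), (0,1,0).
  const std::vector<float> expected{1, 0, 0, 0, 0, 1, 0, 1, 0};
  return one_hot_sketch({0, 2, 1}, 3, 1.0f, 0.0f) == expected ? 0 : 1;
}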
  // In OneHot::configure(): derive the output shape (validity checks elided).
  auto const depth_value = getTensorData<int32_t>(depth())[0];
  auto const &input_shape = indices()->shape();
  auto const input_dims = input_shape.num_dims();
  auto const axis = params().axis == -1 ? input_dims : params().axis;

  // Output shape: the input shape with a dimension of size depth_value inserted at `axis`.
  Shape output_shape(input_dims + 1);
  for (int32_t d = 0; d < axis; ++d)
    output_shape.dim(d) = input_shape.dim(d);
  output_shape.dim(axis) = depth_value;
  for (int32_t d = axis + 1; d < output_shape.num_dims(); ++d)
    output_shape.dim(d) = input_shape.dim(d - 1);
  output()->resize(output_shape);
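// Illustrative sketch (not part of the kernel): the same shape rule expressed
// with std::vector, i.e. insert `depth` into the index shape at position `axis`.
// The function name one_hot_output_shape is hypothetical.
#include <cstdint>
#include <vector>

std::vector<int32_t> one_hot_output_shape(std::vector<int32_t> input_shape, int32_t depth,
                                          int32_t axis)
{
  if (axis == -1)
    axis = static_cast<int32_t>(input_shape.size()); // -1 means "append as last dimension"
  input_shape.insert(input_shape.begin() + axis, depth);
  return input_shape;
}

int main()
{
  const auto shape = one_hot_output_shape({2, 3}, /*depth=*/4, /*axis=*/1); // -> {2, 4, 3}
  return shape == std::vector<int32_t>{2, 4, 3} ? 0 : 1;
}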
  // In OneHot::execute(): dispatch on the output element type.
  auto const depth_value = getTensorData<int32_t>(depth())[0];
  auto const axis = params().axis;
  switch (output()->element_type())
  {
    case loco::DataType::FLOAT32:
      OneHotComputeImpl<float>(indices(), on_value(), off_value(), depth_value, axis, output());
      break;
    case loco::DataType::U8:
      OneHotComputeImpl<uint8_t>(indices(), on_value(), off_value(), depth_value, axis, output());
      break;
    case loco::DataType::S16:
      OneHotComputeImpl<int16_t>(indices(), on_value(), off_value(), depth_value, axis, output());
      break;
    default:
      throw std::runtime_error("Not supported, yet!");
  }
const OneHotParams &params() const
void resize(const Shape &new_shape)
const Shape &shape() const
const Tensor *on_value() const
OneHot(const Tensor *indices, const Tensor *depth, const Tensor *on_value, const Tensor *off_value, Tensor *output, const OneHotParams &params)
const Tensor *off_value() const
const Tensor *depth() const
void execute() const override
const Tensor *indices() const
void configure() override
#define LUCI_INTERPRETER_CHECK(cond) // see the check-macro sketch after this list
const luci_interpreter::RuntimeShape output_shape
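// Illustrative sketch: one common way a check macro like LUCI_INTERPRETER_CHECK
// can be defined -- evaluate the condition and throw with the failing expression
// text. The name MINI_CHECK is hypothetical, and the project's real definition
// may differ (for example, by also reporting file and line information).
#include <stdexcept>
#include <string>

#define MINI_CHECK(cond)                                                   \
  do                                                                       \
  {                                                                        \
    if (!(cond))                                                           \
      throw std::runtime_error(std::string("check failed: ") + #cond);     \
  } while (false)

// Usage pattern, as in configure(): MINI_CHECK(depth()->shape().num_elements() == 1);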