18#include "kernels/OneHot.h"
19#include "kernels/Utils.h"
48 for (int32_t
i = 0;
i < axis; ++
i)
62 for (int32_t
j = 0;
j < depth; ++
j)
99 for (int32_t d = 0; d < axis; ++d)
104 for (int32_t d = axis + 1; d <
output_shape.num_dims(); ++d)
117 switch (
output()->element_type())
119 case loco::DataType::FLOAT32:
122 case loco::DataType::U8:
125 case loco::DataType::S16:
130 throw std::runtime_error(
"Not supported, yet!");
const OneHotParams & params() const
void resize(const Shape &new_shape)
const Shape & shape() const
const Tensor * on_value() const
OneHot(const Tensor *indices, const Tensor *depth, const Tensor *on_value, const Tensor *off_value, Tensor *output, const OneHotParams &params)
const Tensor * off_value() const
const Tensor * depth() const
void execute() const override
const Tensor * indices() const
void configure() override
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
T must_cast(loco::Node *node)