19#include "kernels/Utils.h"
/// Converts `elements_count` consecutive values read from `in_data` and writes
/// them to `out_data`, applying `static_cast<OutT>` to every element.
///
/// @param in_data        source buffer; must hold at least `elements_count` values
/// @param out_data       destination buffer; must hold at least `elements_count` values
/// @param elements_count number of elements to convert (0 is a no-op)
///
/// NOTE(review): per the C++ rules for static_cast, converting a floating-point
/// value that does not fit into an integral OutT is undefined behavior; callers
/// are responsible for the value range.
template <typename InT, typename OutT>
void cast_data(const InT *in_data, OutT *out_data, uint32_t elements_count)
{
  const auto convert = [](InT value) { return static_cast<OutT>(value); };
  const InT *const in_end = in_data + elements_count;
  std::transform(in_data, in_end, out_data, convert);
}
34template <
typename InT>
void cast_from_pointer_to_tensor(
const InT *in_data,
Tensor *out_tensor)
41 case loco::DataType::U8:
42 cast_data(in_data, getTensorData<uint8_t>(out_tensor), elements_count);
44 case loco::DataType::U16:
45 cast_data(in_data, getTensorData<uint16_t>(out_tensor), elements_count);
47 case loco::DataType::U32:
48 cast_data(in_data, getTensorData<uint32_t>(out_tensor), elements_count);
50 case loco::DataType::U64:
51 cast_data(in_data, getTensorData<uint64_t>(out_tensor), elements_count);
53 case loco::DataType::S8:
54 cast_data(in_data, getTensorData<int8_t>(out_tensor), elements_count);
56 case loco::DataType::S16:
57 cast_data(in_data, getTensorData<int16_t>(out_tensor), elements_count);
59 case loco::DataType::S32:
60 cast_data(in_data, getTensorData<int32_t>(out_tensor), elements_count);
62 case loco::DataType::S64:
63 cast_data(in_data, getTensorData<int64_t>(out_tensor), elements_count);
65 case loco::DataType::FLOAT32:
66 cast_data(in_data, getTensorData<float>(out_tensor), elements_count);
68 case loco::DataType::BOOL:
69 cast_data(in_data, getTensorData<bool>(out_tensor), elements_count);
72 throw std::runtime_error(
"Unsupported output type.");
76void cast_from_tensor_to_tensor(
const Tensor *in_tensor,
Tensor *out_tensor)
82 case loco::DataType::U8:
83 cast_from_pointer_to_tensor(getTensorData<uint8_t>(in_tensor), out_tensor);
85 case loco::DataType::U16:
86 cast_from_pointer_to_tensor(getTensorData<uint16_t>(in_tensor), out_tensor);
88 case loco::DataType::U32:
89 cast_from_pointer_to_tensor(getTensorData<uint32_t>(in_tensor), out_tensor);
91 case loco::DataType::U64:
92 cast_from_pointer_to_tensor(getTensorData<uint64_t>(in_tensor), out_tensor);
94 case loco::DataType::S8:
95 cast_from_pointer_to_tensor(getTensorData<int8_t>(in_tensor), out_tensor);
97 case loco::DataType::S16:
98 cast_from_pointer_to_tensor(getTensorData<int16_t>(in_tensor), out_tensor);
100 case loco::DataType::S32:
101 cast_from_pointer_to_tensor(getTensorData<int32_t>(in_tensor), out_tensor);
103 case loco::DataType::S64:
104 cast_from_pointer_to_tensor(getTensorData<int64_t>(in_tensor), out_tensor);
106 case loco::DataType::FLOAT32:
107 cast_from_pointer_to_tensor(getTensorData<float>(in_tensor), out_tensor);
109 case loco::DataType::BOOL:
110 cast_from_pointer_to_tensor(getTensorData<bool>(in_tensor), out_tensor);
113 throw std::runtime_error(
"Unsupported input type.");
// NOTE(review): the lines below look like extraction fragments, not complete
// definitions. The assert is presumably part of Cast::execute(): it requires the
// input and output shapes to describe the same number of elements before the
// element-wise cast runs — confirm against the full source.
137 assert(
input()->shape().num_elements() ==
output()->shape().num_elements());
// Signature-only fragments (bodies not shown here). They presumably belong to
// Shape, Tensor and the Cast kernel class — confirm against their headers.
// Element count of a shape; returns int32_t (see the cast to uint32_t callers need).
int32_t num_elements() const
// Presumably resizes a tensor to `new_shape` — verify.
void resize(const Shape &new_shape)
// Accessor returning a const reference to a shape.
const Shape & shape() const
// Accessor returning an element type (compared against loco::DataType values above).
DataType element_type() const
// Cast kernel constructor taking one input and one output tensor.
Cast(const Tensor *input, Tensor *output)
void configure() override
// Accessor for the kernel's input tensor, as used by the assert above.
const Tensor * input() const
void execute() const override
// Check macro (definition not shown).
#define LUCI_INTERPRETER_CHECK(cond)