19#include "kernels/Utils.h"
45 switch (
input()->element_type())
47 case DataType::FLOAT32:
51 throw std::runtime_error(
"luci-rope Unsupported data type.");
55void RoPE::evalFloat()
const
62 const float *input_data = getTensorData<float>(
input());
63 const float *sin_table_data = getTensorData<float>(
sin_table());
64 const float *cos_table_data = getTensorData<float>(
cos_table());
65 float *output_data = getTensorData<float>(
output());
67 if (
params().mode == RoPEMode::GPT_NEOX)
69 const int32_t i0_n = input_shape.Dims(0);
70 const int32_t i1_n = input_shape.Dims(1);
71 const int32_t i2_n = input_shape.Dims(2);
72 const int32_t i3_n = input_shape.Dims(3);
74 for (int32_t i0 = 0; i0 < i0_n; ++i0)
76 for (int32_t i1 = 0; i1 < i1_n; ++i1)
78 for (int32_t i2 = 0; i2 < i2_n; ++i2)
80 for (int32_t i3 = 0; i3 < i3_n / 2; ++i3)
82 const int32_t
offset = tflite::Offset(input_shape, i0, i1, i2, i3);
83 const float x0 = input_data[
offset];
84 const float x1 = input_data[
offset + i3_n / 2];
86 output_data[
offset] = x0 * cos_table_data[i3] - x1 * sin_table_data[i3];
87 output_data[
offset + i3_n / 2] =
88 x0 * sin_table_data[i3 + i3_n / 2] + x1 * cos_table_data[i3 + i3_n / 2];
95 throw std::runtime_error(
"luci-intp RoPE unsupported mode.");
const RoPEParams & params() const
void resize(const Shape &new_shape)
RoPE(const Tensor *input, const Tensor *sin_table, const Tensor *cos_table, Tensor *output, const RoPEParams ¶ms)
void configure() override
const Tensor * cos_table() const
void execute() const override
const Tensor * sin_table() const
const Tensor * input() const
#define LUCI_INTERPRETER_CHECK(cond)
__global uchar * offset(const Image *img, int x, int y)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)