17#ifndef __NNFW_CKER_ROPE_H__
18#define __NNFW_CKER_ROPE_H__
33 const Shape &sin_table_shape,
const T *sin_table_data,
37 if (input_shape.
Dims(3) != sin_table_shape.
Dims(3))
38 throw std::runtime_error(
"the dimension(3) of input and sin_table do not match");
40 if (input_shape.
Dims(3) != cos_table_shape.
Dims(3))
41 throw std::runtime_error(
"the dimension(3) of input and cos_table do not match");
49 throw std::runtime_error(
"i3_n must be even number");
53 for (int32_t i0 = 0; i0 < i0_n; ++i0)
55 for (int32_t i1 = 0; i1 < i1_n; ++i1)
57 for (int32_t i2 = 0; i2 < i2_n; ++i2)
59 for (int32_t i3 = 0; i3 < i3_n / 2; ++i3)
61 const int32_t
offset =
Offset(input_shape, i0, i1, i2, i3);
62 const T x0 = input_data[
offset];
63 const T x1 = input_data[
offset + i3_n / 2];
65 output_data[
offset] = x0 * cos_table_data[i3] - x1 * sin_table_data[i3];
66 output_data[
offset + i3_n / 2] =
67 x0 * sin_table_data[i3 + i3_n / 2] + x1 * cos_table_data[i3 + i3_n / 2];
75 throw std::runtime_error(
"Unsupported RoPE mode");
int32_t Dims(int i) const
__global uchar * offset(const Image *img, int x, int y)
const luci_interpreter::RuntimeShape output_shape
int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
void RoPE(const RoPEMode mode, const Shape &input_shape, const T *input_data, const Shape &sin_table_shape, const T *sin_table_data, const Shape &cos_table_shape, const T *cos_table_data, const Shape &output_shape, T *output_data)