19#include "kernels/Utils.h"
45 switch (
input()->element_type())
47 case DataType::FLOAT32:
51 throw std::runtime_error(
"luci-rope Unsupported data type.");
55void RoPE::evalFloat()
const
67 if (
params().mode == RoPEMode::GPT_NEOX)
84 const float x1 = input_data[
offset +
i3_n / 2];
95 throw std::runtime_error(
"luci-intp RoPE unsupported mode.");
const RoPEParams & params() const
void resize(const Shape &new_shape)
RoPE(const Tensor *input, const Tensor *sin_table, const Tensor *cos_table, Tensor *output, const RoPEParams ¶ms)
void configure() override
const Tensor * cos_table() const
void execute() const override
const Tensor * sin_table() const
const Tensor * input() const
#define LUCI_INTERPRETER_CHECK(cond)
__global uchar * offset(const Image *img, int x, int y)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
T must_cast(loco::Node *node)