ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::RoPE Class Reference

#include <RoPE.h>

Collaboration diagram for luci_interpreter::kernels::RoPE:

Public Member Functions

 RoPE (const Tensor *input, const Tensor *sin_table, const Tensor *cos_table, Tensor *output, const RoPEParams &params)
 
const Tensor * input () const
 
const Tensor * sin_table () const
 
const Tensor * cos_table () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< RoPEParams >
const RoPEParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< RoPEParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const RoPEParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< RoPEParams >
const RoPEParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file RoPE.h.

Constructor & Destructor Documentation

◆ RoPE()

luci_interpreter::kernels::RoPE::RoPE ( const Tensor * input,
const Tensor * sin_table,
const Tensor * cos_table,
Tensor * output,
const RoPEParams & params 
)

Definition at line 26 of file RoPE.cpp.

28 : KernelWithParams<RoPEParams>({input, sin_table, cos_table}, {output}, params)
29{
30}
Tensor * output() const
Definition RoPE.h:37
const Tensor * cos_table() const
Definition RoPE.h:36
const Tensor * sin_table() const
Definition RoPE.h:35
const Tensor * input() const
Definition RoPE.h:34

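References cos_table(), input(), and sin_table(). (Note: the constructor body does not call these accessors. It forwards its identically named parameters input, sin_table, cos_table, and output to the KernelWithParams base constructor.)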

Member Function Documentation

◆ configure()

void luci_interpreter::kernels::RoPE::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 32 of file RoPE.cpp.

33{
34 LUCI_INTERPRETER_CHECK(input()->shape().num_dims() == 4);
35 LUCI_INTERPRETER_CHECK(sin_table()->shape().dim(3) == input()->shape().dim(3));
36 LUCI_INTERPRETER_CHECK(cos_table()->shape().dim(3) == input()->shape().dim(3));
37
38 LUCI_INTERPRETER_CHECK(params().mode == RoPEMode::GPT_NEOX);
39
40 output()->resize(input()->shape());
41}
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36

References cos_table(), input(), LUCI_INTERPRETER_CHECK, output(), luci_interpreter::KernelWithParams< RoPEParams >::params(), luci_interpreter::Tensor::resize(), and sin_table().

◆ cos_table()

const Tensor * luci_interpreter::kernels::RoPE::cos_table ( ) const
inline

Definition at line 36 of file RoPE.h.

36{ return _inputs[2]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and RoPE().

◆ execute()

void luci_interpreter::kernels::RoPE::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 43 of file RoPE.cpp.

44{
45 switch (input()->element_type())
46 {
47 case DataType::FLOAT32:
48 evalFloat();
49 break;
50 default:
51 throw std::runtime_error("luci-rope Unsupported data type.");
52 }
53}

References input().

◆ input()

const Tensor * luci_interpreter::kernels::RoPE::input ( ) const
inline

Definition at line 34 of file RoPE.h.

34{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and RoPE().

◆ output()

Tensor * luci_interpreter::kernels::RoPE::output ( ) const
inline

Definition at line 37 of file RoPE.h.

37{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure().

◆ sin_table()

const Tensor * luci_interpreter::kernels::RoPE::sin_table ( ) const
inline

Definition at line 35 of file RoPE.h.

35{ return _inputs[1]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and RoPE().


The documentation for this class was generated from the following files: