ONE - On-device Neural Engine
luci_interpreter::kernels::RmsNorm Class Reference

#include <RmsNorm.h>


Public Member Functions

RmsNorm (const Tensor *input, const Tensor *gamma, Tensor *output, const RmsNormParams &params)
const Tensor * input () const
const Tensor * gamma () const
Tensor * output () const
void configure () override
void execute () const override

- Public Member Functions inherited from luci_interpreter::KernelWithParams< RmsNormParams >
const RmsNormParams & params () const

- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
const std::vector< const Tensor * > & getInputTensors () const
const std::vector< Tensor * > & getOutputTensors () const

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< RmsNormParams >
KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const RmsNormParams &params)

- Protected Member Functions inherited from luci_interpreter::Kernel
Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)

- Protected Attributes inherited from luci_interpreter::KernelWithParams< RmsNormParams >
const RmsNormParams _params

- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
const std::vector< Tensor * > _outputs

Detailed Description

Definition at line 28 of file RmsNorm.h.
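This page carries no prose description of the operation. For orientation only: RMS normalization is conventionally computed over the last axis as below, where C is the size of the input's last dimension, n indexes all leading positions, and gamma is broadcast when it holds a single element (the two gamma shapes that configure() accepts). That evalFloat() follows exactly this form, and that the epsilon term is supplied through RmsNormParams, are assumptions not stated on this page.

```latex
y_{n,c} = \gamma_{c}\,\frac{x_{n,c}}{\sqrt{\dfrac{1}{C}\sum_{k=1}^{C} x_{n,k}^{2} + \epsilon}},
\qquad c = 1,\dots,C,
\qquad \gamma_{c} = \gamma_{1} \text{ for all } c \text{ when } \lvert\gamma\rvert = 1
```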

Constructor & Destructor Documentation

◆ RmsNorm()

luci_interpreter::kernels::RmsNorm::RmsNorm (const Tensor *input, const Tensor *gamma, Tensor *output, const RmsNormParams &params)

Definition at line 29 of file RmsNorm.cpp.

  : KernelWithParams<RmsNormParams>({input, gamma}, {output}, params)
{
}

References gamma(), and input().
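To illustrate how the constructor's argument lists land in the inherited _inputs/_outputs vectors and how the inline accessors index them, here is a self-contained sketch. Tensor, RmsNormParams (including its epsilon field) and RmsNormSketch are hypothetical stand-ins, not the luci_interpreter types; only the ordering ({input, gamma} and {output}) and the accessor indices are taken from this page.

```cpp
#include <utility>
#include <vector>

// Hypothetical stand-ins for luci_interpreter::Tensor and RmsNormParams.
// The real types carry shapes, element types and data buffers.
struct Tensor
{
  int id;
};
struct RmsNormParams
{
  float epsilon; // field name is an assumption
};

// Minimal model of the Kernel base class listed on this page.
class Kernel
{
public:
  virtual ~Kernel() = default;
  virtual void configure() = 0;
  virtual void execute() const = 0;
  const std::vector<const Tensor *> &getInputTensors() const { return _inputs; }
  const std::vector<Tensor *> &getOutputTensors() const { return _outputs; }

protected:
  Kernel(std::vector<const Tensor *> inputs, std::vector<Tensor *> outputs)
    : _inputs(std::move(inputs)), _outputs(std::move(outputs))
  {
  }

  const std::vector<const Tensor *> _inputs;
  const std::vector<Tensor *> _outputs;
};

// Minimal model of KernelWithParams<Params>: keeps a copy of the params.
template <typename Params> class KernelWithParams : public Kernel
{
public:
  const Params &params() const { return _params; }

protected:
  KernelWithParams(std::vector<const Tensor *> inputs, std::vector<Tensor *> outputs,
                   const Params &params)
    : Kernel(std::move(inputs), std::move(outputs)), _params(params)
  {
  }

  const Params _params;
};

// Hypothetical RmsNorm look-alike: input is input slot 0, gamma is input slot 1,
// output is output slot 0.
class RmsNormSketch : public KernelWithParams<RmsNormParams>
{
public:
  RmsNormSketch(const Tensor *input, const Tensor *gamma, Tensor *output,
                const RmsNormParams &params)
    : KernelWithParams<RmsNormParams>({input, gamma}, {output}, params)
  {
  }

  const Tensor *input() const { return _inputs[0]; }
  const Tensor *gamma() const { return _inputs[1]; }
  Tensor *output() const { return _outputs[0]; }

  void configure() override {}     // shape checks omitted in this sketch
  void execute() const override {} // computation omitted in this sketch
};
```

Because the base classes own the tensor lists, the derived kernel needs no data members of its own; its accessors are fixed indices into those lists.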

Member Function Documentation

◆ configure()

void luci_interpreter::kernels::RmsNorm::configure ( )
[override, virtual]

Implements luci_interpreter::Kernel.

Definition at line 35 of file RmsNorm.cpp.

{
  auto num_dims = input()->shape().num_dims();
  LUCI_INTERPRETER_CHECK(num_dims == 3 || num_dims == 4);
  LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
  LUCI_INTERPRETER_CHECK(gamma()->element_type() == input()->element_type());
  LUCI_INTERPRETER_CHECK(gamma()->shape().num_dims() == 1);
  LUCI_INTERPRETER_CHECK((gamma()->shape().dim(0) == input()->shape().dim(num_dims - 1)) ||
                         (gamma()->shape().dim(0) == 1));

  output()->resize(input()->shape());
}

References gamma(), input(), LUCI_INTERPRETER_CHECK, luci_interpreter::Shape::num_dims(), output(), luci_interpreter::Tensor::resize(), and luci_interpreter::Tensor::shape().
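The checks in configure() can be restated as a plain predicate, which makes the accepted configurations easy to see: a 3-D or 4-D input, one shared element type, and a 1-D gamma whose length equals the input's last dimension or is 1. This is a hypothetical sketch: DType and rmsNormConfigureOk are illustrative names, and shapes are modelled as std::vector<int32_t> rather than luci_interpreter::Shape.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical element-type tag; the real code compares the DataType values
// returned by Tensor::element_type().
enum class DType
{
  FLOAT32,
  INT8
};

// Returns true when the arguments satisfy every LUCI_INTERPRETER_CHECK in
// RmsNorm::configure(). Shapes are plain dimension lists in this sketch.
bool rmsNormConfigureOk(const std::vector<int32_t> &input_shape, DType input_type,
                        const std::vector<int32_t> &gamma_shape, DType gamma_type,
                        DType output_type)
{
  const std::size_t num_dims = input_shape.size();
  if (num_dims != 3 && num_dims != 4)
    return false; // only 3-D or 4-D inputs are accepted
  if (input_type != output_type || gamma_type != input_type)
    return false; // input, gamma and output must share one element type
  if (gamma_shape.size() != 1)
    return false; // gamma must be 1-D
  // gamma length must match the last (channel) dimension, or be 1 for broadcasting.
  return gamma_shape[0] == input_shape[num_dims - 1] || gamma_shape[0] == 1;
}
```

When every check passes, configure() gives the output exactly the input's shape, so the kernel never changes tensor dimensions.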

◆ execute()

void luci_interpreter::kernels::RmsNorm::execute ( ) const
[override, virtual]

Implements luci_interpreter::Kernel.

Definition at line 48 of file RmsNorm.cpp.

{
  switch (input()->element_type())
  {
    case DataType::FLOAT32:
      evalFloat();
      break;
    default:
      throw std::runtime_error("luci-intp RmsNorm Unsupported type.");
  }
}

References input().
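execute() forwards FLOAT32 inputs to evalFloat(), whose body is not shown on this page. The sketch below is a reference computation of standard RMS normalization over the last axis, wrapped in the same type dispatch. The row-major layout, the epsilon argument and all names (DataTypeTag, rmsNormReference, executeSketch) are assumptions for illustration, not the library's implementation.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical element-type tag standing in for DataType.
enum class DataTypeTag
{
  FLOAT32,
  INT8
};

// Reference RMS normalization over the last axis of a row-major buffer.
// `channels` is the size of the last dimension; gamma holds 1 or `channels` values.
std::vector<float> rmsNormReference(const std::vector<float> &input, std::size_t channels,
                                    const std::vector<float> &gamma, float epsilon)
{
  if (channels == 0 || input.size() % channels != 0)
    throw std::invalid_argument("input size must be a multiple of channels");
  if (gamma.size() != channels && gamma.size() != 1)
    throw std::invalid_argument("gamma must have 1 or `channels` elements");

  std::vector<float> output(input.size());
  const std::size_t rows = input.size() / channels;
  for (std::size_t r = 0; r < rows; ++r)
  {
    const float *x = input.data() + r * channels;
    float *y = output.data() + r * channels;

    // Mean of squares across the channel axis, then its reciprocal square root.
    float sum_sq = 0.0f;
    for (std::size_t c = 0; c < channels; ++c)
      sum_sq += x[c] * x[c];
    const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(channels) + epsilon);

    for (std::size_t c = 0; c < channels; ++c)
    {
      // A single-element gamma is broadcast across all channels.
      const float g = (gamma.size() == 1) ? gamma[0] : gamma[c];
      y[c] = x[c] * inv_rms * g;
    }
  }
  return output;
}

// Mirrors the switch in execute(): FLOAT32 is evaluated, anything else throws.
std::vector<float> executeSketch(DataTypeTag type, const std::vector<float> &input,
                                 std::size_t channels, const std::vector<float> &gamma,
                                 float epsilon)
{
  switch (type)
  {
    case DataTypeTag::FLOAT32:
      return rmsNormReference(input, channels, gamma, epsilon);
    default:
      throw std::runtime_error("luci-intp RmsNorm Unsupported type.");
  }
}
```

The epsilon term keeps an all-zero row finite: its inverse RMS becomes 1/sqrt(epsilon) and the normalized values stay zero instead of turning into NaN.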

◆ gamma()

const Tensor * luci_interpreter::kernels::RmsNorm::gamma ( ) const
[inline]

Definition at line 34 of file RmsNorm.h.

{ return _inputs[1]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and RmsNorm().

◆ input()

const Tensor * luci_interpreter::kernels::RmsNorm::input ( ) const
[inline]

Definition at line 33 of file RmsNorm.h.

{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and RmsNorm().

◆ output()

Tensor * luci_interpreter::kernels::RmsNorm::output ( ) const
[inline]

Definition at line 35 of file RmsNorm.h.

{ return _outputs[0]; }

References luci_interpreter::Kernel::_outputs.

Referenced by configure().


The documentation for this class was generated from the following files: