18#include "kernels/Utils.h"
33 if (
input()->element_type() !=
output()->element_type())
35 throw std::runtime_error(
"Input/output tensor data type mismatch.");
42 switch (
input()->element_type())
44 case DataType::FLOAT32:
49 throw std::runtime_error(
"luci-intp Rsqrt Unsupported type.");
53void Rsqrt::evalFloat()
const
55 auto in = getTensorData<float>(
input());
56 auto out = getTensorData<float>(
output());
58 for (
auto i = in; i != in +
size; ++i)
60 *out = 1.f / std::sqrt(*i);
void resize(const Shape &new_shape)
void execute() const override
void configure() override
Rsqrt(const Tensor *input, Tensor *output)
const Tensor * input() const
tflite::RuntimeShape getTensorShape(const Tensor *tensor)