18#include "kernels/Utils.h"
20#include <tensorflow/lite/kernels/internal/reference/comparisons.h>
40 if (
x()->element_type() == DataType::U8)
50 switch (
x()->element_type())
52 case DataType::FLOAT32:
56 evalInteger<int64_t>();
59 evalInteger<int32_t>();
65 throw std::runtime_error(
"luci-intp GreaterEqual Unsupported type.");
69void GreaterEqual::evalFloat()
const
71 const auto x_data = getTensorData<float>(
x());
72 const auto y_data = getTensorData<float>(
y());
73 auto output_data = getTensorData<bool>(
output());
75 tflite::ComparisonParams op_params;
78 if (op_params.is_broadcast)
80 tflite::reference_ops::Broadcast4DSlowGreaterEqual(op_params,
getTensorShape(
x()), x_data,
91template <
typename T>
void GreaterEqual::evalInteger()
const
93 const auto x_data = getTensorData<T>(
x());
94 const auto y_data = getTensorData<T>(
y());
97 tflite::ComparisonParams op_params;
100 if (op_params.is_broadcast)
102 tflite::reference_ops::Broadcast4DSlowGreaterEqualNoScaling(
108 tflite::reference_ops::GreaterEqualNoScaling(op_params,
getTensorShape(
x()), x_data,
114void GreaterEqual::evalQuantized()
const
116 const auto x_data = getTensorData<uint8_t>(
x());
117 const auto y_data = getTensorData<uint8_t>(
y());
120 tflite::ComparisonParams op_params;
121 op_params.left_shift = 8;
123 op_params.input1_shift = _x_shift;
124 op_params.input1_multiplier = _x_multiplier;
126 op_params.input2_shift = _y_shift;
127 op_params.input2_multiplier = _y_multiplier;
130 if (op_params.is_broadcast)
132 tflite::reference_ops::Broadcast4DSlowGreaterEqualWithScaling(
138 tflite::reference_ops::GreaterEqualWithScaling(op_params,
getTensorShape(
x()), x_data,
void resize(const Shape &new_shape)
const Shape & shape() const
int32_t zero_point() const
GreaterEqual(const Tensor *x, const Tensor *y, Tensor *output)
void configure() override
void execute() const override
#define LUCI_INTERPRETER_CHECK(cond)
Shape calculateShapeForBroadcast(const Shape &input1_shape, const Shape &input2_shape)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void quantizeMultiplierSmallerThanOneExp(double double_multiplier, int32_t *quantized_multiplier, int *left_shift)