ONE - On-device Neural Engine
Loading...
Searching...
No Matches
FloorDiv.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "kernels/FloorDiv.h"
19#include "kernels/Utils.h"
20
21#include <tensorflow/lite/kernels/internal/reference/binary_function.h>
22#include <cmath>
23
24namespace luci_interpreter
25{
26
27namespace kernels
28{
29
30FloorDiv::FloorDiv(const Tensor *input, const Tensor *alpha, Tensor *output)
31 : Kernel({input, alpha}, {output})
32{
33}
34
36{
37 LUCI_INTERPRETER_CHECK(x()->element_type() == output()->element_type());
38 LUCI_INTERPRETER_CHECK(y()->element_type() == output()->element_type());
39
40 output()->resize(calculateShapeForBroadcast(x()->shape(), y()->shape()));
41}
42
44{
45 switch (x()->element_type())
46 {
47 case DataType::FLOAT32:
48 evalFloat();
49 break;
50 default:
51 throw std::runtime_error("luci-intp FloorDiv Unsupported type.");
52 }
53}
54
55void FloorDiv::evalFloat() const
56{
57 auto FloorDivFunc = [](float x, float y) -> float {
58 return std::floor(static_cast<double>(x) / static_cast<double>(y));
59 };
60
61 const auto x_data = getTensorData<float>(x());
62 const auto y_data = getTensorData<float>(y());
63
64 // Check the denominator
65 for (int i = 0; i < getTensorShape(y()).FlatSize(); ++i)
66 {
67 LUCI_INTERPRETER_CHECK(y_data[i] != 0);
68 }
69
70 if (x()->shape() != y()->shape())
71 {
72 tflite::reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
73 getTensorShape(x()), x_data, getTensorShape(y()), y_data, getTensorShape(output()),
74 getTensorData<float>(output()), FloorDivFunc);
75 }
76 else
77 {
78 tflite::reference_ops::BinaryFunction<float, float, float>(
79 getTensorShape(x()), x_data, getTensorShape(y()), y_data, getTensorShape(output()),
80 getTensorData<float>(output()), FloorDivFunc);
81 }
82}
83
84} // namespace kernels
85} // namespace luci_interpreter
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
void execute() const override
Definition FloorDiv.cpp:43
const Tensor * y() const
Definition FloorDiv.h:33
const Tensor * x() const
Definition FloorDiv.h:32
FloorDiv(const Tensor *x, const Tensor *y, Tensor *output)
Definition FloorDiv.cpp:30
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
Shape calculateShapeForBroadcast(const Shape &input1_shape, const Shape &input2_shape)
Definition Utils.cpp:204
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194