ONE - On-device Neural Engine
Loading...
Searching...
No Matches
SelectV2.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "kernels/SelectV2.h"
19#include "kernels/Utils.h"
20
21#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
22// TODO use select.h once the TensorFlow Lite version is upgraded
23// #include <tensorflow/lite/kernels/internal/reference/select.h>
24
25#include <stdexcept>
26
27namespace luci_interpreter
28{
29namespace kernels
30{
31
32SelectV2::SelectV2(const Tensor *condition, const Tensor *t, const Tensor *e, Tensor *output)
33 : Kernel({condition, t, e}, {output})
34{
35}
36
38{
39 LUCI_INTERPRETER_CHECK(condition()->element_type() == DataType::BOOL);
40 LUCI_INTERPRETER_CHECK(t()->element_type() == e()->element_type());
41 LUCI_INTERPRETER_CHECK(t()->element_type() == output()->element_type());
42
43 auto cond_shape = condition()->shape();
44 auto t_shape = t()->shape();
45 auto e_shape = e()->shape();
46
47 output()->resize(
48 calculateShapeForBroadcast(cond_shape, calculateShapeForBroadcast(t_shape, e_shape)));
49}
50
52{
53 auto t_type = t()->element_type();
54 switch (t_type)
55 {
56 case DataType::FLOAT32:
57 evaluate<float>();
58 break;
59 case DataType::S32:
60 evaluate<int32_t>();
61 break;
62 case DataType::S64:
63 evaluate<int64_t>();
64 break;
65 default:
66 throw std::runtime_error("luci-intp SelectV2 unsupported type.");
67 }
68}
69
70template <typename T> void SelectV2::evaluate() const
71{
72 const auto condition_shape = getTensorShape(condition());
73 const auto condition_data = getTensorData<bool>(condition());
74 const auto t_shape = getTensorShape(t());
75 const auto t_data = getTensorData<T>(t());
76 const auto e_shape = getTensorShape(e());
77 const auto e_data = getTensorData<T>(e());
78 const auto output_shape = getTensorShape(output());
79 auto output_data = getTensorData<T>(output());
80
81 tflite::reference_ops::BroadcastSelect5DSlow<bool, T>(
82 condition_shape, condition_data, t_shape, t_data, e_shape, e_data, output_shape, output_data);
83}
84
85} // namespace kernels
86} // namespace luci_interpreter
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
DataType element_type() const
Definition Tensor.h:105
const Tensor * condition() const
Definition SelectV2.h:33
const Tensor * t() const
Definition SelectV2.h:34
SelectV2(const Tensor *cond, const Tensor *t, const Tensor *e, Tensor *output)
Definition SelectV2.cpp:32
void execute() const override
Definition SelectV2.cpp:51
const Tensor * e() const
Definition SelectV2.h:35
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
Shape calculateShapeForBroadcast(const Shape &input1_shape, const Shape &input2_shape)
Definition Utils.cpp:204
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194