ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Select.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "kernels/Select.h"
19#include "kernels/Utils.h"
20
21#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
22// TODO use select.h when version up
23// #include <tensorflow/lite/kernels/internal/reference/select.h>
24
25#include <stdexcept>
26
27namespace luci_interpreter
28{
29
30namespace kernels
31{
32
// Constructs the Select kernel. The three tensors {condition, t, e} are handed to the
// Kernel base-class constructor as the input list and {output} as the output list.
33Select::Select(const Tensor *condition, const Tensor *t, const Tensor *e, Tensor *output)
34 : Kernel({condition, t, e}, {output})
35{
   // Start with the flag cleared. It is reassigned from the condition/t shapes in the
   // block that follows this constructor in the file.
36 _has_low_rank_input_condition = false;
37}
38
40{
   // NOTE(review): the function header for this body (listing line 39) is missing from
   // this extract; presumably this is the kernel's configure step - confirm against the
   // original Select.cpp before relying on it.
   //
   // Type checks: the condition tensor must be BOOL, and t, e and output must all share
   // one element type. What LUCI_INTERPRETER_CHECK does on failure is defined in Utils.h
   // and is not visible here.
41 LUCI_INTERPRETER_CHECK(condition()->element_type() == DataType::BOOL);
42 LUCI_INTERPRETER_CHECK(t()->element_type() == e()->element_type());
43 LUCI_INTERPRETER_CHECK(t()->element_type() == output()->element_type());
44
45 auto cond_shape = condition()->shape();
46 auto cond_num_dims = cond_shape.num_dims();
47 auto t_shape = t()->shape();
48
   // A condition counts as "low rank" when it has zero dimensions, or when it has exactly
   // one dimension whose extent equals the first dimension of t.
   // NOTE(review): t_shape.dim(0) is read without checking t's rank first; confirm that
   // t can never be rank 0 when the condition is rank 1.
49 bool is_input_condition_scalar = cond_num_dims == 0;
50 bool has_rank_one_input_condition = cond_num_dims == 1 && cond_shape.dim(0) == t_shape.dim(0);
51
52 _has_low_rank_input_condition = is_input_condition_scalar || has_rank_one_input_condition;
53
   // The output shape is derived from t and e only; the condition shape does not take
   // part in this computation.
   // NOTE(review): when the low-rank flag is set, evalFloat() calls RankOneSelect. Whether
   // that routine accepts t and e of different (broadcastable) shapes is not visible
   // here - confirm against the TFLite reference implementation.
54 output()->resize(calculateShapeForBroadcast(t()->shape(), e()->shape()));
55}
56
// Runs the kernel by dispatching on the element type of tensor t.
// Only FLOAT32 is handled (via evalFloat()); any other element type throws
// std::runtime_error.
57void Select::execute() const
58{
59 switch (t()->element_type())
60 {
61 case DataType::FLOAT32:
62 evalFloat();
63 break;
64 default:
   // No other element types are implemented in this file.
65 throw std::runtime_error("luci-intp Select unsupported type.");
66 }
67}
68
// Float implementation of Select. Gathers the shape and data of the condition (as bool),
// t, e and output (as float), then forwards them to one of two routines in
// tflite::reference_ops, chosen by _has_low_rank_input_condition.
69void Select::evalFloat() const
70{
71 const auto condition_shape = getTensorShape(condition());
72 const auto condition_data = getTensorData<bool>(condition());
73 const auto t_shape = getTensorShape(t());
74 const auto t_data = getTensorData<float>(t());
75 const auto e_shape = getTensorShape(e());
76 const auto e_data = getTensorData<float>(e());
77 const auto output_shape = getTensorShape(output());
78 auto output_data = getTensorData<float>(output());
79
   // The flag is assigned elsewhere in this file and is true for both a zero-dimensional
   // condition and a one-dimensional condition matching t's first dimension, so both
   // cases take the RankOneSelect path.
   // NOTE(review): presumably RankOneSelect handles the zero-dimensional condition case
   // as well - confirm against the TFLite reference implementation.
80 if (_has_low_rank_input_condition)
81 {
82 tflite::reference_ops::RankOneSelect(condition_shape, condition_data, t_shape, t_data, e_shape,
83 e_data, output_shape, output_data);
84 }
85 else
86 {
   // Every other condition shape goes through the general Select routine.
87 tflite::reference_ops::Select(condition_shape, condition_data, t_shape, t_data, e_shape, e_data,
88 output_shape, output_data);
89 }
90}
91
92} // namespace kernels
93} // namespace luci_interpreter
int32_t dim(int i) const
Definition Tensor.h:41
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
Select(const Tensor *cond, const Tensor *t, const Tensor *e, Tensor *output)
Definition Select.cpp:33
void execute() const override
Definition Select.cpp:57
const Tensor * e() const
Definition Select.h:35
const Tensor * t() const
Definition Select.h:34
const Tensor * condition() const
Definition Select.h:33
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
Shape calculateShapeForBroadcast(const Shape &input1_shape, const Shape &input2_shape)
Definition Utils.cpp:204
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
void Select(const RuntimeShape &input_condition_shape, const D *input_condition_data, const RuntimeShape &input_x_shape, const T *input_x_data, const RuntimeShape &input_y_shape, const T *input_y_data, const RuntimeShape &output_shape, T *output_data)
void RankOneSelect(const RuntimeShape &input_condition_shape, const D *input_condition_data, const RuntimeShape &input_x_shape, const T *input_x_data, const RuntimeShape &input_y_shape, const T *input_y_data, const RuntimeShape &output_shape, T *output_data)