65 const auto input_cond_index = cur_op->inputs()->operator[](kInputTensorCondition);
66 const auto input_x_index = cur_op->inputs()->operator[](kInputTensorX);
67 const auto input_y_index = cur_op->inputs()->operator[](kInputTensorY);
68 const auto output_index = cur_op->outputs()->operator[](kOutputTensor);
70 assert(input_cond_index != -1);
71 assert(input_x_index != -1);
72 assert(input_y_index != -1);
73 assert(output_index != -1);
80 assert(input_cond !=
nullptr);
81 assert(input_x !=
nullptr);
82 assert(input_y !=
nullptr);
91 bool possible_mixed_scaler =
92 Tensor::num_elements(input_cond) == 1 && Tensor::num_elements(input_x) == 1 &&
93 Tensor::num_elements(input_y) == 1 && Tensor::num_elements(output) == 1;
95 bool same_shape = Tensor::num_elements(input_cond) == Tensor::num_elements(input_x) &&
96 Tensor::num_elements(input_x) == Tensor::num_elements(input_y);
99 if (not same_shape and not possible_mixed_scaler)
107 const auto input_cond_index = cur_op->inputs()->operator[](kInputTensorCondition);
108 const auto input_x_index = cur_op->inputs()->operator[](kInputTensorX);
109 const auto input_y_index = cur_op->inputs()->operator[](kInputTensorY);
110 const auto output_index = cur_op->outputs()->operator[](kOutputTensor);
112 assert(input_cond_index != -1);
113 assert(input_x_index != -1);
114 assert(input_y_index != -1);
115 assert(output_index != -1);
122 assert(input_cond !=
nullptr);
123 assert(input_x !=
nullptr);
124 assert(input_y !=
nullptr);
126 bool possible_mixed_scaler =
127 Tensor::num_elements(input_cond) == 1 && Tensor::num_elements(input_x) == 1 &&
128 Tensor::num_elements(input_y) == 1 && Tensor::num_elements(output) == 1;
130 bool same_shape = Tensor::num_elements(input_cond) == Tensor::num_elements(input_x) &&
131 Tensor::num_elements(input_x) == Tensor::num_elements(input_y);
132 bool is_broadcast =
false;
133 if (not possible_mixed_scaler and not same_shape)
136 const auto type = Tensor::element_type(input_x);
140 case DataType::FLOAT32:
141 CallSelect<float>(input_cond, input_x, input_y, output, is_broadcast, runtime_graph);
145 assert(
false &&
"Unsupported type.");