105 assert(context !=
nullptr);
118 assert(opinputs.size() > 0);
119 assert(opoutputs.size() == 1);
122 int32_t concat_axis = 0;
123 tflite::ActivationFunctionType activation = tflite::ActivationFunctionType_NONE;
125 if (
auto *concatenation_params = op->builtin_options_as_ConcatenationOptions())
127 activation = concatenation_params->fused_activation_function();
128 concat_axis = concatenation_params->axis();
130 const int32_t rank =
static_cast<int32_t
>(tensor_context.
shape(opinputs.at(0)).
rank());
135 assert(concat_axis >= 0);
136 assert(concat_axis < rank);
139 assert(activation == tflite::ActivationFunctionType_NONE);
142 std::vector<coco::FeatureObject *> input_objects;
144 for (
auto &input_index : opinputs)
149 input_obj->
bag(input_bag);
152 input_objects.emplace_back(input_obj);
157 assert(last_feature !=
nullptr);
158 assert(last_feature->
bag() !=
nullptr);
172 for (uint32_t n = 1; n < input_objects.size(); ++n)
174 auto const left_feature = last_feature;
175 auto const left_shape = left_feature->
layout()->
shape();
177 auto right_feature = input_objects.at(n);
178 auto right_shape = right_feature->layout()->shape();
181 auto compute_out_dims = [&left_shape, &right_shape, concat_axis](void) {
182 std::array<uint32_t, 4> out_dims;
184 const auto left_dims = as_dims(left_shape);
185 const auto right_dims = as_dims(right_shape);
187 for (uint32_t axis = 0; axis < 4 ; ++axis)
190 assert((concat_axis == axis) || (left_dims[axis] == right_dims[axis]));
192 out_dims[axis] = left_dims[axis];
193 if (axis == concat_axis)
195 out_dims[axis] += right_dims[axis];
202 const auto out_dims = compute_out_dims();
204 const uint32_t B = out_dims[0 ];
205 const uint32_t C = out_dims[3 ];
206 const uint32_t H = out_dims[1 ];
207 const uint32_t W = out_dims[2 ];
211 auto out_bag =
m->entity()->bag()->create(B * num_elements(out_shape));
214 out_feature->
bag(out_bag);
222 concat_f->
axis(as_ConcatF_axis(concat_axis));
223 concat_f->left(left_load);
224 concat_f->right(right_load);
232 last_feature = out_feature;
236 int const ofm_idx = opoutputs.at(0);
237 auto const ofm_shape = tensor_context.
shape(ofm_idx);
239 auto ofm_bag = bags.
bag(ofm_idx);
242 ofm_obj->
bag(ofm_bag);
Class to store context to build IR from tflite.
Extracts and holds operand (tensor) information such as name, shape, and type.
const tensor::Shape & shape(uint32_t tensor_id)