42 {
43 if (shape_x.DimensionsCount() < 2 || shape_y.DimensionsCount() < 2)
44 return;
45
46 std::vector<int32_t> x;
47 std::vector<int32_t> y;
48
49 x.resize(shape_x.DimensionsCount() - 2);
50 y.resize(shape_y.DimensionsCount() - 2);
51
52 for (size_t i = 0; i < x.size(); i++)
53 {
54 x[i] = shape_x.Dims(i);
55 }
56 for (size_t i = 0; i < y.size(); i++)
57 {
58 y[i] = shape_y.Dims(i);
59 }
60
61 _batch_bcast = std::make_unique<BCast>(std::move(x), std::move(y));
62 if (!_batch_bcast->IsValid())
63 return;
64
65 const auto &x_reshaped = _batch_bcast->x_reshape();
66 const auto &y_reshaped = _batch_bcast->y_reshape();
68
69 _x_batch_size = std::accumulate(x_reshaped.cbegin(), x_reshaped.cend(), INT32_C(1),
70 std::multiplies<int32_t>());
71 _y_batch_size = std::accumulate(y_reshaped.cbegin(), y_reshaped.cend(), INT32_C(1),
72 std::multiplies<int32_t>());
74 _output_batch_size = _output_shape.
FlatSize();
75 }
void ReplaceWith(int dimensions_count, const int32_t *dims_data)
const luci_interpreter::RuntimeShape output_shape