18#ifndef ONERT_MICRO_EXECUTE_PAL_PROCESS_BROADCAST_SHAPES_H
19#define ONERT_MICRO_EXECUTE_PAL_PROCESS_BROADCAST_SHAPES_H
56 for (
int i = N - 1; i >= 0; --i)
59 desc_out->
strides[i] = desc_stride;
60 desc_stride *= input_shape.
dims(i);
64template <
int N,
int DIM,
typename Calc>
66 const Calc &calc,
int indexes[N])
68 for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM])
74template <
int N,
int DIM,
typename Calc>
76 const Calc &calc,
int indexes[N])
78 for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM])
80 NDOpsHelperImpl<N, DIM + 1, Calc>(output, calc, indexes);
86template <
int N,
typename Calc>
90 NDOpsHelperImpl<N, 0, Calc>(output, calc, indexes);
104 copyDimsToDesc<N>(extended_input0_shape, desc0_out);
105 copyDimsToDesc<N>(extended_input1_shape, desc1_out);
110 for (
int i = 0; i < N; ++i)
112 const int extent0 = extended_input0_shape.dims(i);
113 const int extent1 = extended_input1_shape.dims(i);
114 if (extent0 != extent1)
119 desc0_out->
extents[i] = extent1;
124 desc1_out->
extents[i] = extent0;
137 return indexes[0] * desc.
strides[0] + indexes[1] * desc.
strides[1] +
167 if (extended_shape0 == extended_shape1)
184 for (
int i = dims_count - 1; i >= 0; --i)
186 if (extended_shape0.dims(i) == extended_shape1.dims(i))
190 else if (extended_shape0.dims(i) == 1)
195 else if (extended_shape1.dims(i) == 1)
static OMRuntimeShape extendedShape(int new_shape_size, const OMRuntimeShape &shape)
int32_t dims(int i) const
int32_t dimensionsCount() const
@ kSecondInputBroadcastsFast
@ kFirstInputBroadcastsFast
std::enable_if< DIM==N-1, void >::type NDOpsHelperImpl(const NdArrayDesc< N > &output, const Calc &calc, int indexes[N])
bool processBroadcastShapes(const core::OMRuntimeShape &shape0, const core::OMRuntimeShape &shape1, core::BinaryArithmeticBroadcastParams *params)
void NdArrayDescsForElementwiseBroadcast(const core::OMRuntimeShape &input0_shape, const core::OMRuntimeShape &input1_shape, NdArrayDesc< N > *desc0_out, NdArrayDesc< N > *desc1_out)
void copyDimsToDesc(const core::OMRuntimeShape &input_shape, NdArrayDesc< N > *desc_out)
void NDOpsHelper(const NdArrayDesc< N > &output, const Calc &calc)
int subscriptToIndex(const NdArrayDesc< 4 > &desc, int i0, int i1, int i2, int i3)
BroadcastableOpCategory broadcast_category