18#ifndef LUCI_INTERPRETER_PAL_PROCESS_BROADCAST_SHAPES_H
19#define LUCI_INTERPRETER_PAL_PROCESS_BROADCAST_SHAPES_H
51 for (
int i = N - 1; i >= 0; --i)
54 desc_out->
strides[i] = desc_stride;
55 desc_stride *= input_shape.
dims(i);
59template <
int N,
int DIM,
typename Calc>
61 const Calc &calc,
int indexes[N])
63 for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM])
69template <
int N,
int DIM,
typename Calc>
71 const Calc &calc,
int indexes[N])
73 for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM])
75 NDOpsHelperImpl<N, DIM + 1, Calc>(output, calc, indexes);
81template <
int N,
typename Calc>
85 NDOpsHelperImpl<N, 0, Calc>(output, calc, indexes);
99 copyDimsToDesc<N>(extended_input0_shape, desc0_out);
100 copyDimsToDesc<N>(extended_input1_shape, desc1_out);
105 for (
int i = 0; i < N; ++i)
107 const int extent0 = extended_input0_shape.dims(i);
108 const int extent1 = extended_input1_shape.dims(i);
109 if (extent0 != extent1)
114 desc0_out->
extents[i] = extent1;
119 desc1_out->
extents[i] = extent0;
132 return indexes[0] * desc.
strides[0] + indexes[1] * desc.
strides[1] +
162 if (extended_shape0 == extended_shape1)
179 for (
int i = dims_count - 1; i >= 0; --i)
181 if (extended_shape0.dims(i) == extended_shape1.dims(i))
185 else if (extended_shape0.dims(i) == 1)
190 else if (extended_shape1.dims(i) == 1)
int32_t dimensionsCount() const
int32_t dims(int i) const
static RuntimeShape extendedShape(int new_shape_size, const RuntimeShape &shape)
void copyDimsToDesc(const luci_interpreter::RuntimeShape &input_shape, NdArrayDesc< N > *desc_out)
@ kSecondInputBroadcastsFast
@ kFirstInputBroadcastsFast
int subscriptToIndex(const NdArrayDesc< 4 > &desc, int i0, int i1, int i2, int i3)
bool ProcessBroadcastShapes(const luci_interpreter::RuntimeShape &shape0, const luci_interpreter::RuntimeShape &shape1, luci_interpreter_pal::ArithmeticParams *params)
std::enable_if< DIM==N-1, void >::type NDOpsHelperImpl(const NdArrayDesc< N > &output, const Calc &calc, int indexes[N])
void NdArrayDescsForElementwiseBroadcast(const luci_interpreter::RuntimeShape &input0_shape, const luci_interpreter::RuntimeShape &input1_shape, NdArrayDesc< N > *desc0_out, NdArrayDesc< N > *desc1_out)
void NDOpsHelper(const NdArrayDesc< N > &output, const Calc &calc)
BroadcastableOpCategory broadcast_category