39 if (_input_shape.
rank() < 1 || _weights_shape.
rank() != 2)
42 auto const input_elems = element_count(&_input_shape);
43 auto const weights_height = _weights_shape.
dim(0).
value();
44 auto const weights_width = _weights_shape.
dim(1).
value();
45 if (weights_height == 0 || weights_width == 0)
47 if (input_elems % weights_width != 0)
49 auto const batch_size = input_elems / weights_width;
50 auto const num_units = weights_height;
53 if (element_count(&_bias_shape) != num_units)
61 _output_shape.
rank(_input_shape.
rank());
62 for (uint32_t i = 0; i < _input_shape.
rank(); i++)
63 _output_shape.
dim(i) = _input_shape.
dim(i);
64 _output_shape.
dim(_input_shape.
rank() - 1) = num_units;
68 _output_shape.
rank(2);
69 _output_shape.
dim(0) = batch_size;
70 _output_shape.
dim(1) = num_units;