60 assert(input_shapes[0].
Dims(1) == 1 && input_shapes[0].
Dims(3) == 1);
68 Tensor transformed_input[5];
71 const int num_inputs = input_shapes.size();
72 std::vector<InputTensor<float>> inputs(num_inputs);
73 for (
int i = 0; i < num_inputs; i++)
75 inputs[i].shape.ReplaceWith(input_shapes[i].DimensionsCount(), input_shapes[i].DimsData());
76 inputs[i].buffer = input_data[i];
77 copyFrom<float>(inputs[i], inputs[i].shape, &transformed_input[i]);
82 output.buffer = output_data;
83 copyFrom<float>(output, output.shape, &transformed_output);
93 const int depth = x.dimension(3);
94 const int size = x.size();
95 const int rest_size =
size / depth;
96 Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
98 Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
99 Eigen::array<int, 1> reduce_dims({0});
100 Eigen::array<int, 2> bcast_spec({rest_size, 1});
102 auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<float>();
103 const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
104 float rest_size_inv =
static_cast<float>(1.0f /
static_cast<float>(rest_size));
106 [[maybe_unused]]
float rest_size_adjust =
107 static_cast<float>(rest_size) /
static_cast<float>(rest_size_minus_one);
109 Eigen::Tensor<float, 1, Eigen::RowMajor> batch_mean(depth);
110 Eigen::Tensor<float, 1, Eigen::RowMajor> batch_variance(depth);
114 batch_mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
115 auto x_centered = x_rest_by_depth - batch_mean.reshape(one_by_depth).broadcast(bcast_spec);
117 batch_variance.device(d) = x_centered.square().sum(reduce_dims) * rest_size_inv;
118 auto scaling_factor = ((batch_variance + param.
epsilon).rsqrt() * scale)
120 .reshape(one_by_depth)
121 .broadcast(bcast_spec);
122 auto x_scaled = x_centered * scaling_factor;
124 (x_scaled +
offset.reshape(one_by_depth).broadcast(bcast_spec)).
template cast<float>();
126 y.reshape(rest_by_depth).device(d) = x_shifted;
128 memcpy(output_data, y.data(),
output_shape.FlatSize() *
sizeof(
float));
135 temp_tensor.
shape.
ReplaceWith(input.shape.DimensionsCount(), input.shape.DimsData());
136 temp_operand.emplace_back(std::make_unique<
float[]>(input.shape.FlatSize()));
137 temp_tensor.
buffer = temp_operand.back().get();
138 memcpy(temp_tensor.
buffer, input.buffer, input.shape.FlatSize() *
sizeof(
float));
140 copyFrom(temp_tensor, shape, output);