61 assert(input_shapes[0].
Dims(1) == 1 && input_shapes[0].
Dims(3) == 1);
69 Tensor transformed_input[5];
72 const int num_inputs = input_shapes.size();
73 std::vector<InputTensor<float>> inputs(num_inputs);
74 for (
int i = 0; i < num_inputs; i++)
76 inputs[i].shape.ReplaceWith(input_shapes[i].DimensionsCount(), input_shapes[i].DimsData());
77 inputs[i].buffer = input_data[i];
78 copyFrom<float>(inputs[i], inputs[i].shape, &transformed_input[i]);
83 output.buffer = output_data;
84 copyFrom<float>(output, output.shape, &transformed_output);
94 const int depth = x.dimension(3);
95 const int size = x.size();
96 const int rest_size =
size / depth;
97 Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
99 Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
100 Eigen::array<int, 1> reduce_dims({0});
101 Eigen::array<int, 2> bcast_spec({rest_size, 1});
103 auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<float>();
104 const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
105 float rest_size_inv =
static_cast<float>(1.0f /
static_cast<float>(rest_size));
107 [[maybe_unused]]
float rest_size_adjust =
108 static_cast<float>(rest_size) /
static_cast<float>(rest_size_minus_one);
110 Eigen::Tensor<float, 1, Eigen::RowMajor> batch_mean(depth);
111 Eigen::Tensor<float, 1, Eigen::RowMajor> batch_variance(depth);
115 batch_mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
116 auto x_centered = x_rest_by_depth - batch_mean.reshape(one_by_depth).broadcast(bcast_spec);
118 batch_variance.device(d) = x_centered.square().sum(reduce_dims) * rest_size_inv;
119 auto scaling_factor = ((batch_variance + param.
epsilon).rsqrt() * scale)
121 .reshape(one_by_depth)
122 .broadcast(bcast_spec);
123 auto x_scaled = x_centered * scaling_factor;
125 (x_scaled +
offset.reshape(one_by_depth).broadcast(bcast_spec)).
template cast<float>();
127 y.reshape(rest_by_depth).device(d) = x_shifted;
129 memcpy(output_data, y.data(),
output_shape.FlatSize() *
sizeof(
float));
136 temp_tensor.
shape.
ReplaceWith(input.shape.DimensionsCount(), input.shape.DimsData());
137 temp_operand.emplace_back(std::make_unique<
float[]>(input.shape.FlatSize()));
138 temp_tensor.
buffer = temp_operand.back().get();
139 memcpy(temp_tensor.
buffer, input.buffer, input.shape.FlatSize() *
sizeof(
float));
141 copyFrom(temp_tensor, shape, output);