void operator()(const Device& d,
                const Eigen::DSizes<Eigen::DenseIndex, 2>& shape,
                const Eigen::array<Eigen::DenseIndex, 2>& logits_bcast,
                const Eigen::array<Eigen::DenseIndex, 2>& labels_bcast,
                /* remaining parameters elided in this excerpt */) {
  T* scratch_ptr = scratch.data();
  T* backprop_ptr = backprop.data();
  T* loss_ptr = loss.data();

  int row_size = shape[1];
  // Materialize the broadcast logits into backprop and the broadcast labels
  // into scratch; both buffers double as per-row workspace below.
  backprop.device(d) = logits.broadcast(logits_bcast);
  scratch.device(d) = labels.broadcast(labels_bcast);
  auto reductionWorker = [&](int64_t begin, int64_t end) -> void {
    for (int64_t i = begin; i < end; i++) {
      T* this_backprop = backprop_ptr + (i * row_size);
      // backprop currently holds the broadcast logits, so this_logits
      // intentionally aliases the same row as this_backprop.
      T* this_logits = backprop_ptr + (i * row_size);
      T* this_labels = scratch_ptr + (i * row_size);

      // Row maximum, used to shift the logits for numerical stability.
      T max_logits = this_logits[0];
      for (int j = 1; j < row_size; j++) {
        max_logits = std::max(max_logits, this_logits[j]);
      }
      // Shift each logit by the row maximum and accumulate the softmax
      // denominator.
      T sum = T(0);
      for (int j = 0; j < row_size; j++) {
        this_backprop[j] = this_logits[j] - max_logits;
        sum = sum + exp(this_backprop[j]);
      }
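      // With shifted logits s_j = x_j - max, the softmax is
      // p_j = exp(s_j) / sum, and (assuming each label row sums to one) the
      // cross-entropy loss for this row is
      //   -sum_j labels_j * log(p_j) = sum_j labels_j * (log(sum) - s_j),
      // with gradient p_j - labels_j; the loop below also pre-scales the
      // gradient by 1 / reduction_size, presumably for the mean reduction.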
      T log_sum = log(sum);

      // Accumulate this row's loss and overwrite backprop with the gradient.
      T loss_sum = T(0);
      for (int j = 0; j < row_size; j++) {
        loss_sum += this_labels[j] * (log_sum - this_backprop[j]);
        this_backprop[j] =
            ((exp(this_backprop[j]) / sum) - this_labels[j]) / reduction_size;
      }
      loss_ptr[i] = loss_sum;
    }
  };
  // Rough per-row cost estimate so the thread pool can choose a shard size.
  const int64_t compute_cycles = 50 * row_size;
  const int64_t input_bytes = sizeof(T) * row_size;
  const int64_t output_bytes = sizeof(T) * row_size;
  const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);

  // Process the shape[0] rows in parallel shards of [begin, end).
  d.parallelFor(shape[0], cost, reductionWorker);
}
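
// A minimal, self-contained sketch (hypothetical helper, not part of the
// original kernel; assumes the same EIGEN_USE_THREADS / Tensor headers) of
// how the device used above dispatches work: parallelFor() splits [0, n)
// into shards sized from the TensorOpCost estimate and calls the worker
// once per contiguous [begin, end) range, just as reductionWorker receives.
static void ParallelForSketch() {
  Eigen::ThreadPool pool(/*num_threads=*/4);
  Eigen::ThreadPoolDevice device(&pool, /*num_cores=*/4);
  // Pretend each of 1024 rows loads and stores 512 bytes and costs
  // 6400 compute cycles, mirroring the estimate above.
  const Eigen::TensorOpCost cost(/*bytes_loaded=*/512, /*bytes_stored=*/512,
                                 /*compute_cycles=*/6400);
  device.parallelFor(/*n=*/1024, cost,
                     [](Eigen::Index begin, Eigen::Index end) {
                       // Each shard receives a contiguous row range
                       // [begin, end); a real worker would process it here.
                     });
}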