ONE - On-device Neural Engine
Loading...
Searching...
No Matches
xent_op.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_EIGEN_XENT_OPS_H__
19#define __NNFW_CKER_EIGEN_XENT_OPS_H__
20
21// From tensorflow/core/kernels/xent_op.cc
22#define EIGEN_USE_THREADS
23
#include <algorithm>
#include <cmath>

#include "unsupported/Eigen/CXX11/Tensor"
26
27// From tensorflow/core/kernels/xent_op.h
28namespace nnfw
29{
30namespace cker
31{
32namespace xent_ops
33{
34namespace functor
35{
36
37// Functor used by XentOp to do the computations.
38template <typename Device, typename T> struct XentFunctor
39{
40 // Computes Cross Entropy loss and backprop.
41 //
42 // logits: batch_size, num_classes.
43 // labels: batch_size, num_classes.
44 // scratch: temporary tensor, dims: batch_size, 1
45 // loss: output tensor for the loss, dims: batch_size.
46 // backprop: output tensor for the backprop, dims: batch_size, num_classes.
47 void operator()(const Device &d, const Eigen::DSizes<Eigen::DenseIndex, 2> &shape,
48 const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast,
49 const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast,
50 typename TTypes<T>::ConstMatrix logits, typename TTypes<T>::ConstMatrix labels,
51 typename TTypes<T>::Matrix scratch, typename TTypes<T>::Vec loss,
52 typename TTypes<T>::Matrix backprop, T reduction_size);
53};
54
55} // namespace functor
56} // namespace xent_ops
57} // namespace cker
58} // namespace nnfw
59
60// From tensorflow/core/kernels/xent_op.cc
61namespace nnfw
62{
63namespace cker
64{
65namespace xent_ops
66{
67
68// Enable CPUDevice only for xent_ops
69using CPUDevice = Eigen::ThreadPoolDevice;
70using Index = Eigen::Index;
71
72// Partial specialization for a CPUDevice, that uses the Eigen implementation
73// from XentEigenImpl.
74namespace functor
75{
76template <typename Device, typename T> struct XentFunctorBase
77{
78 void operator()(const Device &d, const Eigen::DSizes<Eigen::DenseIndex, 2> &shape,
79 const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast,
80 const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast,
81 typename TTypes<T>::ConstMatrix logits, typename TTypes<T>::ConstMatrix labels,
82 typename TTypes<T>::Matrix scratch, typename TTypes<T>::Vec loss,
83 typename TTypes<T>::Matrix backprop, T reduction_size)
84 {
85 T *scratch_ptr = scratch.data();
86 T *backprop_ptr = backprop.data();
87
88 T *loss_ptr = loss.data();
89
90 int row_size = shape[1];
91
92 if (shape[0] > 0)
93 {
94 backprop.device(d) = logits.broadcast(logits_bcast);
95 scratch.device(d) = labels.broadcast(labels_bcast);
96 auto reductionWorker = [&](int64_t begin, int64_t end) -> void {
97 for (int i = begin; i < end; i++)
98 {
99 T *this_backprop = backprop_ptr + (i * row_size);
100 T *this_logits = backprop_ptr + (i * row_size);
101 T *this_labels = scratch_ptr + (i * row_size);
102 T max_logits = this_logits[0];
103
104 // calculating max_logits
105 for (int j = 1; j < row_size; j++)
106 {
107 max_logits = std::max(max_logits, this_logits[j]);
108 }
109
110 T sum = T(0);
111 T loss_sum = T(0);
112
113 for (int j = 0; j < row_size; j++)
114 {
115 // Note that if input is reused than this_logits and this_backprop
116 // is same buffer, so after this calculation this_logits should no
117 // longer be trusted
118 this_backprop[j] = this_logits[j] - max_logits;
119 sum = sum + exp(this_backprop[j]);
120 }
121
122 // loss calculation
123 T log_sum = log(sum);
124 for (int j = 0; j < row_size; j++)
125 {
126 loss_sum += this_labels[j] * (log_sum - this_backprop[j]);
127 this_backprop[j] = ((exp(this_backprop[j]) / sum) - this_labels[j]) / reduction_size;
128 }
129 loss_ptr[i] = loss_sum;
130 }
131 };
132 const int64_t compute_cycles = 50 * row_size;
133 const int64_t input_bytes = sizeof(T) * row_size;
134 const int64_t output_bytes = sizeof(T) * row_size;
135 const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);
136
137 d.parallelFor(shape[0], cost, reductionWorker);
138 }
139 }
140};
141
142template <typename T> struct XentFunctor<CPUDevice, T> : XentFunctorBase<CPUDevice, T>
143{
144};
145
146} // namespace functor
147} // namespace xent_ops
148} // namespace cker
149} // namespace nnfw
150
151#endif // __NNFW_CKER_EIGEN_XENT_OPS_H__
Eigen::Index Index
Definition xent_op.h:70
Eigen::ThreadPoolDevice CPUDevice
Definition xent_op.h:69
ShapeIterator end(const Shape &s)
Definition topk_v2.h:30
int32_t begin[5]
Definition Slice.cpp:33
Eigen::TensorMap< Eigen::Tensor< T, 1, Eigen::RowMajor, IndexType >, Eigen::Aligned > Vec
Definition Tensor.h:64
Eigen::TensorMap< Eigen::Tensor< T, 2, Eigen::RowMajor, IndexType >, Eigen::Aligned > Matrix
Definition Tensor.h:76
Eigen::TensorMap< Eigen::Tensor< const T, 2, Eigen::RowMajor, IndexType >, Eigen::Aligned > ConstMatrix
Definition Tensor.h:78
void operator()(const Device &d, const Eigen::DSizes< Eigen::DenseIndex, 2 > &shape, const Eigen::array< Eigen::DenseIndex, 2 > &logits_bcast, const Eigen::array< Eigen::DenseIndex, 2 > &labels_bcast, typename TTypes< T >::ConstMatrix logits, typename TTypes< T >::ConstMatrix labels, typename TTypes< T >::Matrix scratch, typename TTypes< T >::Vec loss, typename TTypes< T >::Matrix backprop, T reduction_size)
Definition xent_op.h:78
void operator()(const Device &d, const Eigen::DSizes< Eigen::DenseIndex, 2 > &shape, const Eigen::array< Eigen::DenseIndex, 2 > &logits_bcast, const Eigen::array< Eigen::DenseIndex, 2 > &labels_bcast, typename TTypes< T >::ConstMatrix logits, typename TTypes< T >::ConstMatrix labels, typename TTypes< T >::Matrix scratch, typename TTypes< T >::Vec loss, typename TTypes< T >::Matrix backprop, T reduction_size)