36{
37 VERBOSE(TensorPlanner) <<
"Start planning non-constant tensors" << std::endl;
38
40
41
42 std::unordered_map<ir::train::TrainingOperandIndex, uint32_t> uses_map;
43 std::unordered_map<ir::train::TrainingOperandIndex, uint32_t> defs_map;
44
45
46
47 for (const auto &[operand_index, operand_usedefs] : training_usedefs)
48 {
49 const auto &operand = operand_usedefs.operand();
50
51 if (_external_operands.contains(operand_index.index()))
52 continue;
53
54 if (!operand_index.is_forward() || operand.isConstant())
55 continue;
56
57 uses_map[operand_index] = operand_usedefs.getTrainingUses().size();
58 defs_map[operand_index] = operand_usedefs.getTrainingDefs().size();
59 }
60
61
62
63
64
65
66
67 for (const auto &[operand_index, def_count] : defs_map)
68 {
69 if (def_count == 0)
70 tensor_builder->notifyFirstUse(operand_index.index());
71 }
72
73
74
75 std::vector<ir::train::TrainingOperandIndex> operands_last_until_end;
76 for (const auto &[operand_index, use_count] : uses_map)
77 {
78 if (use_count == 0)
79 operands_last_until_end.push_back(operand_index);
80 }
81
82
83
84
85
86
88 for (const auto &op_index : order)
89 {
93
94
95 for (const auto &output : op_outputs)
96 {
97 if (_external_operands.contains(output))
98 continue;
99 if (!tensor_builder->isRegistered(output))
100 continue;
101
103 assert(defs_map.find(output_index) != defs_map.end());
104 assert(defs_map.at(output_index) == 1);
105 defs_map[output_index] = 0;
106 tensor_builder->notifyFirstUse(output_index.index());
107 }
108
109
110
111
112
113 for (const auto &input : op_inputs)
114 {
115 if (_external_operands.contains(input))
116 continue;
117 if (!tensor_builder->isRegistered(input))
118 continue;
119
121 const auto &operand = training_usedefs.at(input_index).operand();
122 if (operand.isConstant())
123 continue;
124
125 assert(training_usedefs.find(input_index) != training_usedefs.end());
126 if (operand.info().isVariable())
127 throw std::runtime_error("The train backend does not support variable tensors");
128 }
129
130 for (const auto &input : op_inputs)
131 {
132 if (_external_operands.contains(input))
133 continue;
134 if (!tensor_builder->isRegistered(input))
135 continue;
136
138 const auto &operand = training_usedefs.at(input_index).operand();
139 if (operand.isConstant())
140 continue;
141
142 assert(uses_map.find(input_index) != uses_map.end());
143 assert(uses_map[input_index] > 0);
145 if (uses_map[input_index] == 0)
146 {
147
148 tensor_builder->notifyLastUse(
input_index.index());
149 }
150 }
151 }
152
153
155 for (const auto &op_index : border)
156 {
160
161 for (const auto &index : op_inputs + op_outputs)
162 {
163 if (_external_operands.contains(index))
164 continue;
165 if (!tensor_builder->isRegistered(index))
166 continue;
167
169 assert(training_usedefs.find(operand_index) != training_usedefs.end());
170 const auto &operand_usedefs = training_usedefs.at(operand_index);
171 const auto &operand = operand_usedefs.operand();
172 if (operand.isConstant())
173 continue;
174
176 assert(operand_usedefs.getTrainingDefs().find(training_op_index) ==
177 operand_usedefs.getTrainingDefs().end());
178
179 const auto &uses = operand_usedefs.getTrainingUses();
180 if (uses.find(training_op_index) != uses.end())
181 {
182 assert(uses_map.find(operand_index) != uses_map.end());
183 assert(uses_map[operand_index] > 0);
184 uses_map[operand_index]--;
185 if (uses_map[operand_index] == 0)
186 {
187
188 tensor_builder->notifyLastUse(operand_index.index());
189 }
190 }
191 }
192 }
193
194 for (const auto &operand_index : operands_last_until_end)
195 {
196 tensor_builder->notifyLastUse(operand_index.index());
197 }
198
199 assert(std::all_of(
200 uses_map.begin(), uses_map.end(),
201 [](std::pair<const ir::train::TrainingOperandIndex, uint32_t> it) { return it.second == 0; }));
202
203 assert(std::all_of(
204 defs_map.begin(), defs_map.end(),
205 [](std::pair<const ir::train::TrainingOperandIndex, uint32_t> it) { return it.second == 0; }));
206
207 VERBOSE(TensorPlanner) <<
"Finish planning non-constant tensors" << std::endl;
208}