32{
33 VERBOSE(TensorPlanner) <<
"Start planning non-constant tensors" << std::endl;
34
36
37
38 std::unordered_map<ir::train::TrainingOperandIndex, uint32_t> uses_map;
39 std::unordered_map<ir::train::TrainingOperandIndex, uint32_t> defs_map;
40
41
42
43 for (const auto &[operand_index, operand_usedefs] : training_usedefs)
44 {
45 const auto &operand = operand_usedefs.operand();
46
47 if (_external_operands.contains(operand_index.index()))
48 continue;
49
50 if (!operand_index.is_forward() || operand.isConstant())
51 continue;
52
53 uses_map[operand_index] = operand_usedefs.getTrainingUses().size();
54 defs_map[operand_index] = operand_usedefs.getTrainingDefs().size();
55 }
56
57
58
59
60
61
62
63 for (const auto &[operand_index, def_count] : defs_map)
64 {
65 if (def_count == 0)
66 tensor_builder->notifyFirstUse(operand_index.index());
67 }
68
69
70
71 std::vector<ir::train::TrainingOperandIndex> operands_last_until_end;
72 for (const auto &[operand_index, use_count] : uses_map)
73 {
74 if (use_count == 0)
75 operands_last_until_end.push_back(operand_index);
76 }
77
78
79
80
81
82
84 for (const auto &op_index : order)
85 {
89
90
91 for (const auto &output : op_outputs)
92 {
93 if (_external_operands.contains(output))
94 continue;
95 if (!tensor_builder->isRegistered(output))
96 continue;
97
99 assert(defs_map.find(output_index) != defs_map.end());
100 assert(defs_map.at(output_index) == 1);
101 defs_map[output_index] = 0;
102 tensor_builder->notifyFirstUse(output_index.index());
103 }
104
105
106
107
108
109 for (const auto &input : op_inputs)
110 {
111 if (_external_operands.contains(input))
112 continue;
113 if (!tensor_builder->isRegistered(input))
114 continue;
115
117 const auto &operand = training_usedefs.at(input_index).operand();
118 if (operand.isConstant())
119 continue;
120
121 assert(training_usedefs.find(input_index) != training_usedefs.end());
122 if (operand.info().isVariable())
123 throw std::runtime_error("The train backend does not support variable tensors");
124 }
125
126 for (const auto &input : op_inputs)
127 {
128 if (_external_operands.contains(input))
129 continue;
130 if (!tensor_builder->isRegistered(input))
131 continue;
132
134 const auto &operand = training_usedefs.at(input_index).operand();
135 if (operand.isConstant())
136 continue;
137
138 assert(uses_map.find(input_index) != uses_map.end());
139 assert(uses_map[input_index] > 0);
141 if (uses_map[input_index] == 0)
142 {
143
144 tensor_builder->notifyLastUse(
input_index.index());
145 }
146 }
147 }
148
149
151 for (const auto &op_index : border)
152 {
156
157 for (const auto &index : op_inputs + op_outputs)
158 {
159 if (_external_operands.contains(index))
160 continue;
161 if (!tensor_builder->isRegistered(index))
162 continue;
163
165 assert(training_usedefs.find(operand_index) != training_usedefs.end());
166 const auto &operand_usedefs = training_usedefs.at(operand_index);
167 const auto &operand = operand_usedefs.operand();
168 if (operand.isConstant())
169 continue;
170
172 assert(operand_usedefs.getTrainingDefs().find(training_op_index) ==
173 operand_usedefs.getTrainingDefs().end());
174
175 const auto &uses = operand_usedefs.getTrainingUses();
176 if (uses.find(training_op_index) != uses.end())
177 {
178 assert(uses_map.find(operand_index) != uses_map.end());
179 assert(uses_map[operand_index] > 0);
180 uses_map[operand_index]--;
181 if (uses_map[operand_index] == 0)
182 {
183
184 tensor_builder->notifyLastUse(operand_index.index());
185 }
186 }
187 }
188 }
189
190 for (const auto &operand_index : operands_last_until_end)
191 {
192 tensor_builder->notifyLastUse(operand_index.index());
193 }
194
195 assert(std::all_of(
196 uses_map.begin(), uses_map.end(),
197 [](std::pair<const ir::train::TrainingOperandIndex, uint32_t> it) { return it.second == 0; }));
198
199 assert(std::all_of(
200 defs_map.begin(), defs_map.end(),
201 [](std::pair<const ir::train::TrainingOperandIndex, uint32_t> it) { return it.second == 0; }));
202
203 VERBOSE(TensorPlanner) <<
"Finish planning non-constant tensors" << std::endl;
204}