39 const std::shared_ptr<TensorRegistry> &tensor_reg,
40 const std::shared_ptr<ExternalContext> &external_context,
62 std::shared_ptr<TensorRegistry> _tensor_reg;
63 const std::shared_ptr<ExternalContext> _external_context;
65 std::vector<std::unique_ptr<exec::train::IGradientApplier>> _update_funcs;
66 std::unordered_map<const ir::IOperation *, ir::OperationIndex> _node_to_idx;