|
| 1 | +import copy |
1 | 2 | import inspect
|
2 | 3 | import math
|
3 | 4 | import re
|
|
10 | 11 | import torch
|
11 | 12 | import torchvision
|
12 | 13 | from torch import fx, nn
|
13 |
| -from torch.fx.graph_module import _copy_attr |
| 14 | +from torch.fx.graph_module import _CodeOnlyModule, _copy_attr, _USER_PRESERVED_ATTRIBUTES_KEY |
14 | 15 |
|
15 | 16 |
|
16 | 17 | __all__ = ["create_feature_extractor", "get_graph_node_names"]
|
@@ -330,6 +331,40 @@ def train(self, mode=True):
|
330 | 331 | self.graph = self.eval_graph
|
331 | 332 | return super().train(mode=mode)
|
332 | 333 |
|
| 334 | + def _deepcopy_init(self): |
| 335 | + # See __deepcopy__ below |
| 336 | + return DualGraphModule.__init__ |
| 337 | + |
| 338 | + def __deepcopy__(self, memo): |
| 339 | + # Same as the base class' __deepcopy__ from pytorch, with minor |
| 340 | + # modification to account for train_graph and eval_graph |
| 341 | + # https://github.com/pytorch/pytorch/blob/f684dbd0026f98f8fa291cab74dbc4d61ba30580/torch/fx/graph_module.py#L875 |
| 342 | + # |
| 343 | + # This is using a bunch of private stuff from torch, so if that breaks, |
| 344 | + # we'll likely have to remove this, along with the associated |
| 345 | + # non-regression test. |
| 346 | + res = type(self).__new__(type(self)) |
| 347 | + memo[id(self)] = res |
| 348 | + fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo)) |
| 349 | + self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["train_graph"], fake_mod.__dict__["eval_graph"]) |
| 350 | + |
| 351 | + extra_preserved_attrs = [ |
| 352 | + "_state_dict_hooks", |
| 353 | + "_load_state_dict_pre_hooks", |
| 354 | + "_load_state_dict_post_hooks", |
| 355 | + "_replace_hook", |
| 356 | + "_create_node_hooks", |
| 357 | + "_erase_node_hooks", |
| 358 | + ] |
| 359 | + for attr in extra_preserved_attrs: |
| 360 | + if attr in self.__dict__: |
| 361 | + setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo)) |
| 362 | + res.meta = copy.deepcopy(getattr(self, "meta", {}), memo) |
| 363 | + if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta: |
| 364 | + for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items(): |
| 365 | + setattr(res, attr_name, attr) |
| 366 | + return res |
| 367 | + |
333 | 368 |
|
334 | 369 | def create_feature_extractor(
|
335 | 370 | model: nn.Module,
|
|
0 commit comments