Skip to content

Commit e21cd0b

Browse files
__deepcopy__ for DualGraphModule (#8708)
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com> Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
1 parent 22e86bd commit e21cd0b

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

test/test_backbone_utils.py

+12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import random
2+
from copy import deepcopy
23
from itertools import chain
34
from typing import Mapping, Sequence
45

@@ -322,3 +323,14 @@ def forward(self, x):
322323
out = model(self.inp)
323324
# And backward
324325
out["leaf_module"].float().mean().backward()
326+
327+
def test_deepcopy(self):
328+
# Non-regression test for https://github.com/pytorch/vision/issues/8634
329+
model = models.efficientnet_b3(weights=None)
330+
extractor = create_feature_extractor(model=model, return_nodes={"classifier.0": "out"})
331+
332+
extractor.eval()
333+
extractor.train()
334+
extractor = deepcopy(extractor)
335+
extractor.eval()
336+
extractor.train()

torchvision/models/feature_extraction.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import inspect
23
import math
34
import re
@@ -10,7 +11,7 @@
1011
import torch
1112
import torchvision
1213
from torch import fx, nn
13-
from torch.fx.graph_module import _copy_attr
14+
from torch.fx.graph_module import _CodeOnlyModule, _copy_attr, _USER_PRESERVED_ATTRIBUTES_KEY
1415

1516

1617
__all__ = ["create_feature_extractor", "get_graph_node_names"]
@@ -330,6 +331,40 @@ def train(self, mode=True):
330331
self.graph = self.eval_graph
331332
return super().train(mode=mode)
332333

334+
def _deepcopy_init(self):
335+
# See __deepcopy__ below
336+
return DualGraphModule.__init__
337+
338+
def __deepcopy__(self, memo):
339+
# Same as the base class' __deepcopy__ from pytorch, with minor
340+
# modification to account for train_graph and eval_graph
341+
# https://github.com/pytorch/pytorch/blob/f684dbd0026f98f8fa291cab74dbc4d61ba30580/torch/fx/graph_module.py#L875
342+
#
343+
# This is using a bunch of private stuff from torch, so if that breaks,
344+
# we'll likely have to remove this, along with the associated
345+
# non-regression test.
346+
res = type(self).__new__(type(self))
347+
memo[id(self)] = res
348+
fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
349+
self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["train_graph"], fake_mod.__dict__["eval_graph"])
350+
351+
extra_preserved_attrs = [
352+
"_state_dict_hooks",
353+
"_load_state_dict_pre_hooks",
354+
"_load_state_dict_post_hooks",
355+
"_replace_hook",
356+
"_create_node_hooks",
357+
"_erase_node_hooks",
358+
]
359+
for attr in extra_preserved_attrs:
360+
if attr in self.__dict__:
361+
setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
362+
res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
363+
if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
364+
for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
365+
setattr(res, attr_name, attr)
366+
return res
367+
333368

334369
def create_feature_extractor(
335370
model: nn.Module,

0 commit comments

Comments
 (0)