Replace usage of copy.deepcopy() - Convolution/Batch Norm Fuser in FX #2645

Status: Closed · wants to merge 2 commits
8 changes: 6 additions & 2 deletions intermediate_source/fx_conv_bn_fuser.py
@@ -104,7 +104,9 @@ def fuse_conv_bn_eval(conv, bn):
     module `C` such that C(x) == B(A(x)) in inference mode.
     """
     assert(not (conv.training or bn.training)), "Fusion only for eval!"
-    fused_conv = copy.deepcopy(conv)
+    fused_conv = type(conv)(conv.in_channels, conv.out_channels, conv.kernel_size)

@msaroufim (Member) commented on Nov 6, 2023:
This fix seems weird? The right thing to do feels like implementing a proper __deepcopy__() for nn modules? @albanD

This popular thread seems to validate this fix: https://discuss.pytorch.org/t/deep-copying-pytorch-modules/13514. But I don't know if this is what we want people to actually do.
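
For context, a minimal sketch of what such a hand-rolled __deepcopy__ could look like: rebuild the module from its constructor and transfer the state_dict, instead of deep-copying the whole object graph. The class and names below are hypothetical illustrations, not PyTorch's actual implementation.

```python
import copy
import torch

class MyConvBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3)

    def __deepcopy__(self, memo):
        # Fresh instance, then copy parameters and buffers over.
        new = type(self)(self.in_channels, self.out_channels)
        new.load_state_dict(copy.deepcopy(self.state_dict(), memo))
        return new

block = MyConvBlock(3, 8)
clone = copy.deepcopy(block)  # dispatches to MyConvBlock.__deepcopy__
```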

Author replied:

> This fix seems weird? The right thing to do feels like implementing a proper __deepcopy__() for nn modules? @albanD
>
> This popular thread seems to validate this fix: https://discuss.pytorch.org/t/deep-copying-pytorch-modules/13514. But I don't know if this is what we want people to actually do.

We can save and load the model, as found in #2385. Other than this, is there another way that I am missing that would help me make a plausible fix?
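
A rough sketch of that save/load round trip, under the assumption that it means serializing only the state_dict to an in-memory buffer (the helper name is made up, and it ignores non-default constructor arguments such as stride, padding, or bias=False, just like the change in this PR):

```python
import io
import torch

def clone_conv_via_state_dict(conv: torch.nn.Conv2d) -> torch.nn.Conv2d:
    # Serialize only the state_dict, never the module object itself.
    buffer = io.BytesIO()
    torch.save(conv.state_dict(), buffer)
    buffer.seek(0)
    # Rebuild a fresh module of the same type and load the weights back in.
    new_conv = type(conv)(conv.in_channels, conv.out_channels, conv.kernel_size)
    new_conv.load_state_dict(torch.load(buffer))
    return new_conv
```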

Contributor commented:

The problem is that Module is a complex enough class that deepcopying it is very challenging (the same way we don't recommend you serialize the module as-is, but only its state_dict).
deepcopy() works in most simple cases, but it is expected to fail sometimes.
If you only have a regular Conv2d kernel, doing a deepcopy or calling the constructor again is pretty much the same thing, though.
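
A quick check of that equivalence for a plain Conv2d (illustrative snippet, not part of the PR):

```python
import copy
import torch

# For a plain Conv2d, deepcopy and "fresh constructor + load_state_dict"
# produce modules that compute identical outputs.
conv = torch.nn.Conv2d(3, 8, kernel_size=3).eval()

copied = copy.deepcopy(conv)

rebuilt = torch.nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size)
rebuilt.load_state_dict(conv.state_dict())
rebuilt.eval()

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    assert torch.equal(copied(x), rebuilt(x))
```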

Author replied:

How would you want me to proceed with the PR?

Member commented:

I think what @albanD is saying is that in this specific case deepcopy-ing a conv layer is just fine, i.e. the original code probably doesn't need to be changed.

+    fused_conv.load_state_dict(conv.state_dict())
+    fused_conv.eval()
 
     fused_conv.weight, fused_conv.bias = \
         fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
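
For reference, a sketch of the standard conv/BN folding that fuse_conv_bn_weights performs. This is the textbook derivation with illustrative names, not the tutorial's actual code, and it assumes the conv has a bias:

```python
import torch

# BN(conv(x)) = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
# folds into the conv, per output channel, as
#   W' = W * gamma / sqrt(var + eps)
#   b' = (b - mean) * gamma / sqrt(var + eps) + beta
def fuse_weights_sketch(w, b, mean, var, eps, gamma, beta):
    scale = gamma / torch.sqrt(var + eps)
    fused_w = w * scale.reshape(-1, 1, 1, 1)  # scale each output channel
    fused_b = (b - mean) * scale + beta
    return torch.nn.Parameter(fused_w), torch.nn.Parameter(fused_b)
```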
@@ -150,7 +152,9 @@ def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):


 def fuse(model: torch.nn.Module) -> torch.nn.Module:
-    model = copy.deepcopy(model)
+    model, state_dict = type(model)(), model.state_dict()

Member commented:

That's only going to work for models that do not take any parameters to __init__().

@svekars, with this + https://github.com/pytorch/tutorials/pull/2645/files#r1391396959, I'm tempted to think that the original issue is probably irrelevant for this tutorial. Even if copy.deepcopy(model) may not be perfect, it's still better than any alternative that has been proposed so far. Perhaps we could close the original issue and still credit the contributor for their efforts?
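
A small illustration of that point, with a hypothetical model whose __init__ takes a required argument:

```python
import torch

# type(model)() only works when __init__ takes no required arguments.
class Net(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.fc = torch.nn.Linear(hidden_size, hidden_size)

model = Net(hidden_size=16)
try:
    type(model)()
except TypeError as e:
    print(e)  # __init__() missing 1 required positional argument: 'hidden_size'
```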

+    model.load_state_dict(state_dict)
+    model.eval()
     # The first step of most FX passes is to symbolically trace our model to
     # obtain a `GraphModule`. This is a representation of our original model
     # that is functionally identical to our original model, except that we now