Replace usage of copy.deepcopy() - Convolution/Batch Norm Fuser in FX #2645
@@ -104,7 +104,9 @@ def fuse_conv_bn_eval(conv, bn):
     module `C` such that C(x) == B(A(x)) in inference mode.
     """
     assert(not (conv.training or bn.training)), "Fusion only for eval!"
-    fused_conv = copy.deepcopy(conv)
+    fused_conv = type(conv)(conv.in_channels, conv.out_channels, conv.kernel_size)
+    fused_conv.load_state_dict(conv.state_dict())
+    fused_conv.eval()

     fused_conv.weight, fused_conv.bias = \
         fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
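For readers following along, here is a rough standalone sketch (with made-up layer sizes and tensor shapes, not part of the tutorial) of what this replacement does and where it can fall short: the conv is rebuilt from only in_channels, out_channels, and kernel_size, and its parameters are copied back via load_state_dict(), so any constructor setting that is not stored in the state dict (stride, padding, dilation, groups, and so on) silently reverts to its default.

import torch
import torch.nn as nn

# Hypothetical conv layer; shapes are made up for illustration.
conv = nn.Conv2d(3, 8, kernel_size=3).eval()

# The pattern from the diff: rebuild the layer, then copy parameters/buffers.
copied = type(conv)(conv.in_channels, conv.out_channels, conv.kernel_size)
copied.load_state_dict(conv.state_dict())
copied.eval()

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(conv(x), copied(x))  # identical outputs for this simple case

# Caveat: hyperparameters that live outside the state dict are lost.
strided = nn.Conv2d(3, 8, kernel_size=3, stride=2).eval()
rebuilt = type(strided)(strided.in_channels, strided.out_channels, strided.kernel_size)
rebuilt.load_state_dict(strided.state_dict())  # loads cleanly, but stride is back to 1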
@@ -150,7 +152,9 @@ def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):


 def fuse(model: torch.nn.Module) -> torch.nn.Module:
-    model = copy.deepcopy(model)
+    model, state_dict = type(model)(), model.state_dict()
+    model.load_state_dict(state_dict)
+    model.eval()
     # The first step of most FX passes is to symbolically trace our model to
     # obtain a `GraphModule`. This is a representation of our original model
     # that is functionally identical to our original model, except that we now

Review comment (inline, on the `model, state_dict = type(model)(), model.state_dict()` line): That's only going to work for models that do not take any parameters to their constructor. @svekars, with this + https://github.com/pytorch/tutorials/pull/2645/files#r1391396959, I'm tempted to think that the original issue is probably irrelevant for this tutorial. Even if …
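To make that limitation concrete, a small sketch with a hypothetical model class (not from the tutorial): type(model)() invokes the model's constructor with no arguments, so any model whose __init__ has required parameters cannot be copied this way.

import torch.nn as nn

class TinyModel(nn.Module):
    # Hypothetical model whose constructor requires an argument.
    def __init__(self, hidden):
        super().__init__()
        self.fc = nn.Linear(hidden, hidden)

model = TinyModel(hidden=16)

try:
    clone = type(model)()  # TypeError: __init__() missing required argument 'hidden'
    clone.load_state_dict(model.state_dict())
except TypeError as exc:
    print("cannot reconstruct this model:", exc)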
Review comment: This fix seems weird? The right thing to do feels like implementing a proper __deepcopy__() for nn modules? @albanD: this popular thread seems to validate the fix (https://discuss.pytorch.org/t/deep-copying-pytorch-modules/13514), but I don't know if this is what we want people to actually do.
PR author: We can save and load the model, as found in 2385. Other than this, is there another way I am missing that would help me make a plausible fix?
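If that suggestion means round-tripping the whole module through torch.save/torch.load (an assumption; the referenced 2385 is not shown here), a minimal sketch could look like the following. Note that recent PyTorch releases default torch.load to weights_only=True, so loading a full module needs weights_only=False, and pickling a whole module carries much the same caveats as deepcopying it.

import io
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3).eval()  # hypothetical layer

# Round-trip the whole module through an in-memory buffer.
buffer = io.BytesIO()
torch.save(conv, buffer)
buffer.seek(0)
conv_copy = torch.load(buffer, weights_only=False)  # weights_only=False needed on newer PyTorch
conv_copy.eval()

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(conv(x), conv_copy(x))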
Review comment: The problem is that Module is a complex enough class that deepcopying it is very challenging (the same way we don't recommend serializing the module as-is, but only its state_dict). deepcopy() works in most simple cases, but it is expected to fail sometimes. If you only have a regular Conv2d kernel, doing a deepcopy or calling the constructor again is pretty much the same thing, though.
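To illustrate that last point, a short standalone sketch (made-up shapes, not from the tutorial): for a plain Conv2d, both approaches produce an equivalent copy.

import copy
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3).eval()

deep = copy.deepcopy(conv)                       # works fine for a simple layer
rebuilt = type(conv)(conv.in_channels, conv.out_channels, conv.kernel_size)
rebuilt.load_state_dict(conv.state_dict())       # the constructor-based copy
rebuilt.eval()

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(conv(x), deep(x))
assert torch.allclose(conv(x), rebuilt(x))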
PR author: How would you like me to proceed with the PR?
Review comment: I think what @albanD is saying is that in this specific case, deepcopy-ing a conv layer is just fine, i.e. the original code probably doesn't need to be changed.