Skip to content

Commit 0fc6fb6

Browse files
committed
removed deepcopy from fuse method
1 parent f05f050 commit 0fc6fb6

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

Diff for: intermediate_source/fx_conv_bn_fuser.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torc
150150

151151

152152
def fuse(model: torch.nn.Module) -> torch.nn.Module:
153-
model = copy.deepcopy(model)
153+
model, state_dict = type(model)(), model.state_dict()
154+
model.load_state_dict(state_dict)
155+
model.eval()
154156
# The first step of most FX passes is to symbolically trace our model to
155157
# obtain a `GraphModule`. This is a representation of our original model
156158
# that is functionally identical to our original model, except that we now

0 commit comments

Comments
 (0)