
Commit bd004fd

stephenyan1231 authored and facebook-github-bot committed
do not reduce batch norm running var by eps
Summary: Currently, `FrozenBatchNorm2d` reduces `running_var` by `eps` when loading a checkpoint whose version is smaller than 3. In the PyTorch `BatchNorm2d` module, the version is 2 (https://fburl.com/diffusion/lz7jhjcv). Both `FrozenBatchNorm2d` and `BatchNorm2d` implement the forward pass with `F.batch_norm` using the same `eps` argument, so we should NOT reduce `running_var` by `eps` when the checkpoint version is smaller than 3. Otherwise, when `running_var` in the checkpoint is close to zero, the subtraction cancels the `eps` safeguard and incurs NaNs after the checkpoint is loaded.

Reviewed By: ppwwyyxx

Differential Revision: D29190708

fbshipit-source-id: ddd98080ce9c108768d3a5a03e28522208196385
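A minimal numeric sketch of why the subtraction is harmful. `F.batch_norm` normalizes with a denominator of `sqrt(running_var + eps)`; if the loader has already subtracted `eps`, the safeguard cancels out and the denominator can collapse to zero. The helper name below is hypothetical, not part of detectron2:

```python
import math

def effective_denominator(checkpoint_var: float, eps: float, old_loader: bool) -> float:
    """Sketch of the sqrt(var + eps) denominator used by F.batch_norm.

    old_loader=True mimics the removed upgrade path, which subtracted
    eps from running_var at checkpoint-load time.
    """
    var = checkpoint_var - eps if old_loader else checkpoint_var
    return math.sqrt(var + eps)

eps = 1e-5
# With the fix: eps keeps the denominator strictly positive even for var == 0.
print(effective_denominator(0.0, eps, old_loader=False))  # sqrt(1e-5) > 0
# Old behavior: the subtraction cancels eps, the denominator collapses to 0,
# and (x - mean) / 0 then produces inf/NaN downstream.
print(effective_denominator(0.0, eps, old_loader=True))
```

The same cancellation happens for any `running_var`; it only becomes visible as NaN/inf when the variance in the checkpoint is near zero, which matches the failure described above.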
1 parent d285dea commit bd004fd

File tree

1 file changed: +0 −9 lines

detectron2/layers/batch_norm.py

@@ -78,15 +78,6 @@ def _load_from_state_dict(
         if prefix + "running_var" not in state_dict:
             state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)
 
-        # NOTE: if a checkpoint is trained with BatchNorm and loaded (together with
-        # version number) to FrozenBatchNorm, running_var will be wrong. One solution
-        # is to remove the version number from the checkpoint.
-        if version is not None and version < 3:
-            logger = logging.getLogger(__name__)
-            logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
-            # In version < 3, running_var are used without +eps.
-            state_dict[prefix + "running_var"] -= self.eps
-
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
