Skip to content

Commit dba8a76

Browse files
authored
Merge pull request #25 from DreSimpelo/patch-3
Update to batch_norm (Unwanted ema updates)
2 parents afbdfe5 + 1e327bd commit dba8a76

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

python/libs/batch_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def batch_norm(x, phase_train, scope='bn', affine=True):
4141

4242
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
4343
ema = tf.train.ExponentialMovingAverage(decay=0.9)
44-
ema_apply_op = ema.apply([batch_mean, batch_var])
4544
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
4645

4746
def mean_var_with_update():
@@ -52,6 +51,7 @@ def mean_var_with_update():
5251
name : TYPE
5352
Description
5453
"""
54+
ema_apply_op = ema.apply([batch_mean, batch_var])
5555
with tf.control_dependencies([ema_apply_op]):
5656
return tf.identity(batch_mean), tf.identity(batch_var)
5757
mean, var = control_flow_ops.cond(phase_train,

0 commit comments

Comments
 (0)