
Commit d55816b

lorenrose1013 authored and saran-t committed
Formatting fixes and internal changes.
PiperOrigin-RevId: 414656804
1 parent 826ff89 commit d55816b

6 files changed: +83 -57 lines

nfnets/experiment.py
+66 -40
@@ -75,28 +75,33 @@ def get_config():
           bfloat16=False,
           lr_schedule=dict(
               name='WarmupCosineDecay',
-              kwargs=dict(num_steps=config.training_steps,
-                          start_val=0,
-                          min_val=0,
-                          warmup_steps=5*steps_per_epoch),
-              ),
+              kwargs=dict(
+                  num_steps=config.training_steps,
+                  start_val=0,
+                  min_val=0,
+                  warmup_steps=5 * steps_per_epoch),
+          ),
           lr_scale_by_bs=True,
           optimizer=dict(
               name='SGD',
-              kwargs={'momentum': 0.9, 'nesterov': True,
-                      'weight_decay': 1e-4,},
+              kwargs={
+                  'momentum': 0.9,
+                  'nesterov': True,
+                  'weight_decay': 1e-4,
+              },
           ),
           model_kwargs=dict(
               width=4,
               which_norm='BatchNorm',
-              norm_kwargs=dict(create_scale=True,
-                               create_offset=True,
-                               decay_rate=0.9,
-                               ),  # cross_replica_axis='i'),
+              norm_kwargs=dict(
+                  create_scale=True,
+                  create_offset=True,
+                  decay_rate=0.9,
+              ),  # cross_replica_axis='i'),
               variant='ResNet50',
               activation='relu',
               drop_rate=0.0,
-              ),
+          ),
       ),))

   # Training loop config: log and checkpoint every minute
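Note on the schedule spelled out by the reflowed kwargs: linear warmup for 5 epochs, then cosine decay to min_val over the remaining training steps, with the peak value coming from the configured learning rate (optionally scaled by batch size via lr_scale_by_bs). The sketch below is only an illustration of that shape; it is not the repo's optim.WarmupCosineDecay, whose details may differ.

import math


def warmup_cosine_decay(step, num_steps, peak_val, start_val=0.0, min_val=0.0,
                        warmup_steps=0):
  """Illustrative warmup + cosine-decay schedule (one common definition)."""
  if warmup_steps and step < warmup_steps:
    # Linear warmup from start_val up to the peak learning rate.
    return start_val + (peak_val - start_val) * step / warmup_steps
  # Cosine decay from the peak down to min_val over the remaining steps.
  frac = min(1.0, (step - warmup_steps) / max(1, num_steps - warmup_steps))
  return min_val + 0.5 * (peak_val - min_val) * (1.0 + math.cos(math.pi * frac))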
@@ -136,8 +141,9 @@ def __init__(self, mode, config, init_rng):
     self._eval_input = None

     # Get model, loaded in from the zoo
+    module_prefix = 'nfnets.'
     self.model_module = importlib.import_module(
-        ('nfnets.'+ self.config.model.lower()))
+        (module_prefix + self.config.model.lower()))
     self.net = hk.transform_with_state(self._forward_fn)

     # Assign image sizes
@@ -154,17 +160,17 @@ def __init__(self, mode, config, init_rng):
     self.test_imsize = variant_dict.get('test_imsize', test_imsize)

     donate_argnums = (0, 1, 2, 6, 7) if self.config.use_ema else (0, 1, 2)
-    self.train_fn = jax.pmap(self._train_fn, axis_name='i',
-                             donate_argnums=donate_argnums)
+    self.train_fn = jax.pmap(
+        self._train_fn, axis_name='i', donate_argnums=donate_argnums)
     self.eval_fn = jax.pmap(self._eval_fn, axis_name='i')

   def _initialize_train(self):
     self._train_input = self._build_train_input()
     # Initialize net and EMA copy of net if no params available.
     if self._params is None:
       inputs = next(self._train_input)
-      init_net = jax.pmap(lambda *a: self.net.init(*a, is_training=True),
-                          axis_name='i')
+      init_net = jax.pmap(
+          lambda *a: self.net.init(*a, is_training=True), axis_name='i')
       init_rng = jl_utils.bcast_local_devices(self.init_rng)
       self._params, self._state = init_net(init_rng, inputs)
       if self.config.use_ema:
@@ -176,8 +182,9 @@ def _initialize_train(self):
   def _make_opt(self):
     # Separate conv params and gains/biases
     def pred(mod, name, val):  # pylint:disable=unused-argument
-      return (name in ['scale', 'offset', 'b']
-              or 'gain' in name or 'bias' in name)
+      return (name in ['scale', 'offset', 'b'] or 'gain' in name or
+              'bias' in name)
+
     gains_biases, weights = hk.data_structures.partition(pred, self._params)
     # Lr schedule with batch-based LR scaling
     if self.config.lr_scale_by_bs:
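The reflowed pred function is what decides which parameters are exempt from weight decay: anything named scale, offset, or b, plus any gain/bias parameter. A minimal sketch of how hk.data_structures.partition splits a params mapping with such a predicate (the toy dict is a made-up stand-in for the output of net.init):

import haiku as hk
import jax.numpy as jnp


def pred(module_name, name, value):  # pylint: disable=unused-argument
  # Matches norm scales/offsets, plain biases, and any gain/bias parameter.
  return (name in ['scale', 'offset', 'b'] or 'gain' in name or 'bias' in name)


# Toy params mapping standing in for the result of net.init.
params = {
    'res_net/conv_0': {'w': jnp.ones((3, 3, 3, 16)), 'b': jnp.zeros((16,))},
    'res_net/batch_norm': {'scale': jnp.ones((16,)), 'offset': jnp.zeros((16,))},
}
gains_biases, weights = hk.data_structures.partition(pred, params)
# gains_biases -> the 'b', 'scale' and 'offset' entries (no weight decay)
# weights      -> everything else (weight decay applied)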
@@ -190,16 +197,22 @@ def pred(mod, name, val):  # pylint:disable=unused-argument
     opt_kwargs = {key: val for key, val in self.config.optimizer.kwargs.items()}
     opt_kwargs['lr'] = lr_schedule
     opt_module = getattr(optim, self.config.optimizer.name)
-    self.opt = opt_module([{'params': gains_biases, 'weight_decay': None},
-                           {'params': weights}], **opt_kwargs)
+    self.opt = opt_module([{
+        'params': gains_biases,
+        'weight_decay': None
+    }, {
+        'params': weights
+    }], **opt_kwargs)
     if self._opt_state is None:
       self._opt_state = self.opt.states()
     else:
       self.opt.plugin(self._opt_state)

   def _forward_fn(self, inputs, is_training):
-    net_kwargs = {'num_classes': self.config.num_classes,
-                  **self.config.model_kwargs}
+    net_kwargs = {
+        'num_classes': self.config.num_classes,
+        **self.config.model_kwargs
+    }
     net = getattr(self.model_module, self.config.model)(**net_kwargs)
     if self.config.get('transpose', False):
       images = jnp.transpose(inputs['images'], (3, 0, 1, 2))  # HWCN -> NHWC
@@ -236,8 +249,7 @@ def _loss_fn(self, params, state, inputs, rng):
     scaled_loss = loss / jax.device_count()  # Grads get psummed so do divide
     return scaled_loss, (metrics, state)

-  def _train_fn(self, params, states, opt_states,
-                inputs, rng, global_step,
+  def _train_fn(self, params, states, opt_states, inputs, rng, global_step,
                 ema_params, ema_states):
     """Runs one batch forward + backward and run a single opt step."""
     grad_fn = jax.grad(self._loss_fn, argnums=0, has_aux=True)
@@ -260,9 +272,14 @@ def _train_fn(self, params, states, opt_states,
       ema = lambda x, y: ema_fn(x, y, self.config.ema_decay, global_step)
       ema_params = jax.tree_multimap(ema, ema_params, params)
       ema_states = jax.tree_multimap(ema, ema_states, states)
-    return {'params': params, 'states': states, 'opt_states': opt_states,
-            'ema_params': ema_params, 'ema_states': ema_states,
-            'metrics': metrics}
+    return {
+        'params': params,
+        'states': states,
+        'opt_states': opt_states,
+        'ema_params': ema_params,
+        'ema_states': ema_states,
+        'metrics': metrics
+    }

   #  _            _
   # | |_ _ __ __ _(_)_ __
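The return-dict reflow sits right after the EMA update, which moves each EMA leaf toward the current parameter/state value via jax.tree_multimap. Below is a generic sketch of such an update, not the repo's ema_fn (which additionally takes global_step, presumably for decay warmup); newer JAX spells tree_multimap as jax.tree_util.tree_map.

import jax


def ema_update(ema_tree, new_tree, decay=0.9999):
  # Exponential moving average over a pytree of parameters or state:
  # each leaf moves toward the new value with rate (1 - decay).
  return jax.tree_util.tree_map(
      lambda e, p: decay * e + (1.0 - decay) * p, ema_tree, new_tree)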
@@ -275,10 +292,15 @@ def step(self, global_step, rng, *unused_args, **unused_kwargs):
     if self._train_input is None:
       self._initialize_train()
     inputs = next(self._train_input)
-    out = self.train_fn(params=self._params, states=self._state,
-                        opt_states=self._opt_state, inputs=inputs,
-                        rng=rng, global_step=global_step,
-                        ema_params=self._ema_params, ema_states=self._ema_state)
+    out = self.train_fn(
+        params=self._params,
+        states=self._state,
+        opt_states=self._opt_state,
+        inputs=inputs,
+        rng=rng,
+        global_step=global_step,
+        ema_params=self._ema_params,
+        ema_states=self._ema_state)
     self._params, self._state = out['params'], out['states']
     self._opt_state = out['opt_states']
     self._ema_params, self._ema_state = out['ema_params'], out['ema_states']
@@ -294,7 +316,8 @@ def _build_train_input(self):
           f'Global batch size {global_batch_size} must be divisible by '
           f'num devices {num_devices}')
     return dataset.load(
-        dataset.Split.TRAIN_AND_VALID, is_training=True,
+        dataset.Split.TRAIN_AND_VALID,
+        is_training=True,
         batch_dims=[jax.local_device_count(), bs_per_device],
         transpose=self.config.get('transpose', False),
         image_size=(self.train_imsize,) * 2,
@@ -350,13 +373,16 @@ def _build_eval_input(self):
     bs_per_device = (self.config.eval_batch_size // jax.local_device_count())
     split = dataset.Split.from_string(self.config.eval_subset)
     eval_preproc = self.config.get('eval_preproc', 'crop_resize')
-    return dataset.load(split, is_training=False,
-                        batch_dims=[jax.local_device_count(), bs_per_device],
-                        transpose=self.config.get('transpose', False),
-                        image_size=(self.test_imsize,) * 2,
-                        name=self.config.which_dataset,
-                        eval_preproc=eval_preproc,
-                        fake_data=self.config.get('fake_data', False))
+    return dataset.load(
+        split,
+        is_training=False,
+        batch_dims=[jax.local_device_count(), bs_per_device],
+        transpose=self.config.get('transpose', False),
+        image_size=(self.test_imsize,) * 2,
+        name=self.config.which_dataset,
+        eval_preproc=eval_preproc,
+        fake_data=self.config.get('fake_data', False))
+

 if __name__ == '__main__':
   flags.mark_flag_as_required('config')
nfnets/experiment_nfnets.py
+12
@@ -14,11 +14,18 @@
 # ==============================================================================
 r"""ImageNet experiment with NFNets."""

+import sys
+
+from absl import flags
 import haiku as hk
+from jaxline import platform
 from ml_collections import config_dict
+
 from nfnets import experiment
 from nfnets import optim

+FLAGS = flags.FLAGS
+

 def get_config():
   """Return config object for training."""
@@ -124,3 +131,8 @@ def pred_fc(mod, name, val):
       self._opt_state = self.opt.states()
     else:
       self.opt.plugin(self._opt_state)
+
+
+if __name__ == '__main__':
+  flags.mark_flag_as_required('config')
+  platform.main(Experiment, sys.argv[1:])

nfnets/fixup_resnet.py
+1 -5
@@ -121,11 +121,8 @@ def count_flops(self, h, w):
       flops += [block.count_flops(h, w)]
       if block.stride > 1:
         h, w = h / block.stride, w / block.stride
-    # Head module FLOPs
-    out_ch = self.blocks[-1].out_ch
-    flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)]
     # Count flops for classifier
-    flops += [self.final_conv.output_channels * self.fc.output_size]
+    flops += [self.blocks[-1].out_ch * self.fc.output_size]
     return flops, sum(flops)

@@ -213,4 +210,3 @@ def count_flops(self, h, w):
     sc_flops = 0
     contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
     return sum([expand_flops, dw_flops, contract_flops, sc_flops])
-
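The same FLOP-count change appears again in nf_resnet.py and skipinit_resnet.py below: the head-module term is dropped and the classifier cost is taken directly from the last block's output channels. For a dense classifier the multiply-add count is simply in_features * out_features; with hypothetical numbers:

# Hypothetical numbers, only to illustrate the classifier term kept by the fix:
# a dense layer costs roughly in_features * out_features multiply-adds.
final_block_out_ch = 2048  # stands in for self.blocks[-1].out_ch
num_classes = 1000         # stands in for self.fc.output_size
classifier_flops = final_block_out_ch * num_classes  # 2_048_000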
nfnets/nf_resnet.py
+2 -6
@@ -81,7 +81,7 @@ def __init__(self, num_classes, variant='ResNet50', width=4,
           )]
       ch = block_width
       index += 1
-      # Reset expected std but still give it 1 block of growth
+      # Reset expected std but still give it 1 block of growth
       if block_index == 0:
         expected_std = 1.0
       expected_std = (expected_std **2 + alpha**2)**0.5
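The touched comment belongs to the NF-ResNet variance bookkeeping: expected_std tracks the predicted standard deviation at the residual-branch input, resets to 1 at the start of each stage, and grows with alpha after every block. Toy arithmetic, assuming alpha = 0.2 as in the NF-ResNet paper (the repo's default may differ):

# Toy numbers to illustrate the expected-std recursion shown above.
alpha = 0.2
expected_std = 1.0                                # reset at the start of a stage
expected_std = (expected_std**2 + alpha**2)**0.5  # ~1.0198 after one block
expected_std = (expected_std**2 + alpha**2)**0.5  # ~1.0392 after two blocks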
@@ -124,11 +124,8 @@ def count_flops(self, h, w):
       flops += [block.count_flops(h, w)]
       if block.stride > 1:
         h, w = h / block.stride, w / block.stride
-    # Head module FLOPs
-    out_ch = self.blocks[-1].out_ch
-    flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)]
     # Count flops for classifier
-    flops += [self.final_conv.output_channels * self.fc.output_size]
+    flops += [self.blocks[-1].out_ch * self.fc.output_size]
     return flops, sum(flops)

@@ -213,4 +210,3 @@ def count_flops(self, h, w):
     se_flops += self.se.fc0.output_size * self.se.fc1.output_size
     contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
     return sum([expand_flops, dw_flops, se_flops, contract_flops, sc_flops])
-

nfnets/optim.py
+1 -1
@@ -132,7 +132,7 @@ def __getattr__(self, name):
     elif '_states' in self.__dict__ and name in self._states:
       return self._states[name]
     else:
-      object.__getattr__(self, name)
+      object.__getattribute__(self, name)

   def step(self, params, grads, states, itr=None):
     """Takes a single optimizer step.

nfnets/skipinit_resnet.py
+1 -5
@@ -117,11 +117,8 @@ def count_flops(self, h, w):
       flops += [block.count_flops(h, w)]
       if block.stride > 1:
         h, w = h / block.stride, w / block.stride
-    # Head module FLOPs
-    out_ch = self.blocks[-1].out_ch
-    flops += [base.count_conv_flops(out_ch, self.final_conv, h, w)]
     # Count flops for classifier
-    flops += [self.final_conv.output_channels * self.fc.output_size]
+    flops += [self.blocks[-1].out_ch * self.fc.output_size]
     return flops, sum(flops)

@@ -191,4 +188,3 @@ def count_flops(self, h, w):
     # SE flops happen on avg-pooled activations
     contract_flops = base.count_conv_flops(self.width, self.conv2, h, w)
     return sum([expand_flops, dw_flops, contract_flops, sc_flops])
-
