@@ -75,28 +75,33 @@ def get_config():
              bfloat16=False,
              lr_schedule=dict(
                  name='WarmupCosineDecay',
-                 kwargs=dict(num_steps=config.training_steps,
-                             start_val=0,
-                             min_val=0,
-                             warmup_steps=5 * steps_per_epoch),
-                 ),
+                 kwargs=dict(
+                     num_steps=config.training_steps,
+                     start_val=0,
+                     min_val=0,
+                     warmup_steps=5 * steps_per_epoch),
+             ),
              lr_scale_by_bs=True,
              optimizer=dict(
                  name='SGD',
-                 kwargs={'momentum': 0.9, 'nesterov': True,
-                         'weight_decay': 1e-4,},
+                 kwargs={
+                     'momentum': 0.9,
+                     'nesterov': True,
+                     'weight_decay': 1e-4,
+                 },
              ),
              model_kwargs=dict(
                  width=4,
                  which_norm='BatchNorm',
-                 norm_kwargs=dict(create_scale=True,
-                                  create_offset=True,
-                                  decay_rate=0.9,
-                                  ),  # cross_replica_axis='i'),
+                 norm_kwargs=dict(
+                     create_scale=True,
+                     create_offset=True,
+                     decay_rate=0.9,
+                 ),  # cross_replica_axis='i'),
                  variant='ResNet50',
                  activation='relu',
                  drop_rate=0.0,
-                 ),
+             ),
          ),))

  # Training loop config: log and checkpoint every minute
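For reference, the 'WarmupCosineDecay' schedule named in this config is defined elsewhere in the repo (in its optim module). A minimal sketch of the usual shape such a schedule computes, assuming linear warmup from start_val to the peak lr followed by cosine decay down to min_val:

import numpy as np

def warmup_cosine_decay(step, lr, num_steps, start_val=0, min_val=0,
                        warmup_steps=0):
  # Linear warmup from start_val to the peak lr over warmup_steps.
  if step < warmup_steps:
    return start_val + (lr - start_val) * step / warmup_steps
  # Cosine decay from the peak lr down to min_val over the remaining steps.
  t = (step - warmup_steps) / (num_steps - warmup_steps)
  return min_val + 0.5 * (lr - min_val) * (1 + np.cos(np.pi * t))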
@@ -136,8 +141,9 @@ def __init__(self, mode, config, init_rng):
    self._eval_input = None

    # Get model, loaded in from the zoo
+    module_prefix = 'nfnets.'
    self.model_module = importlib.import_module(
-        ('nfnets.' + self.config.model.lower()))
+        (module_prefix + self.config.model.lower()))
    self.net = hk.transform_with_state(self._forward_fn)

    # Assign image sizes
@@ -154,17 +160,17 @@ def __init__(self, mode, config, init_rng):
    self.test_imsize = variant_dict.get('test_imsize', test_imsize)

    donate_argnums = (0, 1, 2, 6, 7) if self.config.use_ema else (0, 1, 2)
-    self.train_fn = jax.pmap(self._train_fn, axis_name='i',
-                             donate_argnums=donate_argnums)
+    self.train_fn = jax.pmap(
+        self._train_fn, axis_name='i', donate_argnums=donate_argnums)
    self.eval_fn = jax.pmap(self._eval_fn, axis_name='i')

  def _initialize_train(self):
    self._train_input = self._build_train_input()
    # Initialize net and EMA copy of net if no params available.
    if self._params is None:
      inputs = next(self._train_input)
-      init_net = jax.pmap(lambda *a: self.net.init(*a, is_training=True),
-                          axis_name='i')
+      init_net = jax.pmap(
+          lambda *a: self.net.init(*a, is_training=True), axis_name='i')
      init_rng = jl_utils.bcast_local_devices(self.init_rng)
      self._params, self._state = init_net(init_rng, inputs)
      if self.config.use_ema:
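Both pmap calls reflowed above rely on donate_argnums, which tells XLA it may reuse the device buffers of the listed positional arguments for the outputs; with use_ema the two extra EMA trees (argnums 6 and 7) are donated as well. A self-contained toy illustration (the update function and shapes here are invented, not the repo's):

import jax
import jax.numpy as jnp

# Donating argument 0 lets XLA reuse its buffer for the output, roughly
# halving peak memory for large parameter updates.
update = jax.pmap(lambda p, g: p - 0.1 * g, axis_name='i',
                  donate_argnums=(0,))

n = jax.local_device_count()
params = jnp.ones((n, 4))
grads = jnp.ones((n, 4))
params = update(params, grads)  # the old params buffer may now be invalid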
@@ -176,8 +182,9 @@ def _initialize_train(self):
  def _make_opt(self):
    # Separate conv params and gains/biases
    def pred(mod, name, val):  # pylint:disable=unused-argument
-      return (name in ['scale', 'offset', 'b']
-              or 'gain' in name or 'bias' in name)
+      return (name in ['scale', 'offset', 'b'] or 'gain' in name or
+              'bias' in name)
+
    gains_biases, weights = hk.data_structures.partition(pred, self._params)
    # Lr schedule with batch-based LR scaling
    if self.config.lr_scale_by_bs:
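For context, hk.data_structures.partition splits a Haiku parameter tree in two using a predicate over (module_name, parameter_name, value); here it separates norm scales/offsets, biases, and gains from the convolutional weights so that weight decay can be switched off for the former (see the optimizer groups in the next hunk). A toy example with invented parameter values:

import haiku as hk
import jax.numpy as jnp

# Toy params in Haiku's {module: {name: array}} layout (values invented).
params = {
    'res_block/conv_0': {'w': jnp.ones((3, 3, 16, 16)), 'b': jnp.zeros(16)},
    'res_block/batch_norm': {'scale': jnp.ones(16), 'offset': jnp.zeros(16)},
}

def pred(mod, name, val):  # same predicate as in the diff
  return (name in ['scale', 'offset', 'b'] or 'gain' in name or
          'bias' in name)

# gains_biases gets 'b', 'scale' and 'offset'; weights gets the kernel 'w'.
gains_biases, weights = hk.data_structures.partition(pred, params)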
@@ -190,16 +197,22 @@ def pred(mod, name, val):  # pylint:disable=unused-argument
    opt_kwargs = {key: val for key, val in self.config.optimizer.kwargs.items()}
    opt_kwargs['lr'] = lr_schedule
    opt_module = getattr(optim, self.config.optimizer.name)
-    self.opt = opt_module([{'params': gains_biases, 'weight_decay': None},
-                           {'params': weights}], **opt_kwargs)
+    self.opt = opt_module([{
+        'params': gains_biases,
+        'weight_decay': None
+    }, {
+        'params': weights
+    }], **opt_kwargs)
    if self._opt_state is None:
      self._opt_state = self.opt.states()
    else:
      self.opt.plugin(self._opt_state)

  def _forward_fn(self, inputs, is_training):
-    net_kwargs = {'num_classes': self.config.num_classes,
-                  **self.config.model_kwargs}
+    net_kwargs = {
+        'num_classes': self.config.num_classes,
+        **self.config.model_kwargs
+    }
    net = getattr(self.model_module, self.config.model)(**net_kwargs)
    if self.config.get('transpose', False):
      images = jnp.transpose(inputs['images'], (3, 0, 1, 2))  # HWCN -> NHWC
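The trailing context line is the on-device half of the 'double-transpose trick' enabled when transpose=True in the config: the input pipeline delivers batches as HWCN (reportedly faster for host-to-device transfer on TPU), and the forward pass transposes them back to NHWC. A minimal sketch of that transpose:

import jax.numpy as jnp

images_hwcn = jnp.zeros((224, 224, 3, 8))               # H, W, C, N as fed in
images_nhwc = jnp.transpose(images_hwcn, (3, 0, 1, 2))  # -> N, H, W, C
assert images_nhwc.shape == (8, 224, 224, 3)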
@@ -236,8 +249,7 @@ def _loss_fn(self, params, state, inputs, rng):
    scaled_loss = loss / jax.device_count()  # Grads get psummed so do divide
    return scaled_loss, (metrics, state)

-  def _train_fn(self, params, states, opt_states,
-                inputs, rng, global_step,
+  def _train_fn(self, params, states, opt_states, inputs, rng, global_step,
                ema_params, ema_states):
    """Runs one batch forward + backward and run a single opt step."""
    grad_fn = jax.grad(self._loss_fn, argnums=0, has_aux=True)
@@ -260,9 +272,14 @@ def _train_fn(self, params, states, opt_states,
    ema = lambda x, y: ema_fn(x, y, self.config.ema_decay, global_step)
    ema_params = jax.tree_multimap(ema, ema_params, params)
    ema_states = jax.tree_multimap(ema, ema_states, states)
-    return {'params': params, 'states': states, 'opt_states': opt_states,
-            'ema_params': ema_params, 'ema_states': ema_states,
-            'metrics': metrics}
+    return {
+        'params': params,
+        'states': states,
+        'opt_states': opt_states,
+        'ema_params': ema_params,
+        'ema_states': ema_states,
+        'metrics': metrics
+    }

  #  _             _
  # | |_ _ __ __ _(_)_ __
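The return reflowed above follows a leaf-wise EMA update of the parameter and state trees. ema_fn itself is defined elsewhere (the config selects which_ema='tf1_ema'); a minimal stand-in of the standard form, applied with the same jax.tree_multimap pattern as the diff:

import jax
import jax.numpy as jnp

# Plain EMA; the repo's 'tf1_ema' variant may additionally warm the decay
# up as a function of global_step (an assumption, not verified here).
def ema_fn(ema_x, x, decay, global_step):
  del global_step  # unused in this simplified form
  return ema_x * decay + x * (1.0 - decay)

params = {'w': jnp.ones(3)}
ema_params = {'w': jnp.zeros(3)}
ema = lambda x, y: ema_fn(x, y, 0.9999, global_step=0)
ema_params = jax.tree_multimap(ema, ema_params, params)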
@@ -275,10 +292,15 @@ def step(self, global_step, rng, *unused_args, **unused_kwargs):
    if self._train_input is None:
      self._initialize_train()
    inputs = next(self._train_input)
-    out = self.train_fn(params=self._params, states=self._state,
-                        opt_states=self._opt_state, inputs=inputs,
-                        rng=rng, global_step=global_step,
-                        ema_params=self._ema_params, ema_states=self._ema_state)
+    out = self.train_fn(
+        params=self._params,
+        states=self._state,
+        opt_states=self._opt_state,
+        inputs=inputs,
+        rng=rng,
+        global_step=global_step,
+        ema_params=self._ema_params,
+        ema_states=self._ema_state)
    self._params, self._state = out['params'], out['states']
    self._opt_state = out['opt_states']
    self._ema_params, self._ema_state = out['ema_params'], out['ema_states']
@@ -294,7 +316,8 @@ def _build_train_input(self):
          f'Global batch size {global_batch_size} must be divisible by '
          f'num devices {num_devices}')
    return dataset.load(
-        dataset.Split.TRAIN_AND_VALID, is_training=True,
+        dataset.Split.TRAIN_AND_VALID,
+        is_training=True,
        batch_dims=[jax.local_device_count(), bs_per_device],
        transpose=self.config.get('transpose', False),
        image_size=(self.train_imsize,) * 2,
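batch_dims=[jax.local_device_count(), bs_per_device] asks the loader for batches with a leading device axis, which is what the pmapped train and eval functions expect. Sketching the arithmetic (dataset.load is the repo's own loader; shapes assume NHWC without the transpose trick):

import jax

global_batch_size = 1024  # train_batch_size in the config above
bs_per_device = global_batch_size // jax.device_count()
# Each element the loader yields then has shape
# [jax.local_device_count(), bs_per_device, H, W, C], and jax.pmap maps
# the leading axis onto the local devices.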
@@ -350,13 +373,16 @@ def _build_eval_input(self):
    bs_per_device = (self.config.eval_batch_size // jax.local_device_count())
    split = dataset.Split.from_string(self.config.eval_subset)
    eval_preproc = self.config.get('eval_preproc', 'crop_resize')
-    return dataset.load(split, is_training=False,
-                        batch_dims=[jax.local_device_count(), bs_per_device],
-                        transpose=self.config.get('transpose', False),
-                        image_size=(self.test_imsize,) * 2,
-                        name=self.config.which_dataset,
-                        eval_preproc=eval_preproc,
-                        fake_data=self.config.get('fake_data', False))
+    return dataset.load(
+        split,
+        is_training=False,
+        batch_dims=[jax.local_device_count(), bs_per_device],
+        transpose=self.config.get('transpose', False),
+        image_size=(self.test_imsize,) * 2,
+        name=self.config.which_dataset,
+        eval_preproc=eval_preproc,
+        fake_data=self.config.get('fake_data', False))
+


if __name__ == '__main__':
  flags.mark_flag_as_required('config')