@@ -191,6 +191,12 @@ def main():
191
191
192
192
build_strategy = fluid .BuildStrategy ()
193
193
build_strategy .fuse_all_optimizer_ops = True
194
+ try :
195
+ fluid .require_version (min_version = '1.7.0' )
196
+ build_strategy .enable_auto_fusion = args .enable_auto_fusion
197
+ except Exception as e :
198
+ logger .info ("PaddlePaddle version 1.7.0 or higher is "
199
+ "required when you want to enable fusion_group." )
194
200
195
201
if args .parallel :
196
202
train_program = fluid .compiler .CompiledProgram (
@@ -438,32 +444,35 @@ def data_gen():
438
444
print ("ptblm\t lstm_language_model_%s_loss_card%d\t %s" %
439
445
(args .rnn_model , device_count , train_ppl [0 ]))
440
446
441
- # NOTE(zjl): sometimes we have not enough data for eval if batch_size is large, i.e., 2100
442
- # Just skip to avoid error
443
- def is_valid_data (data , batch_size , num_steps ):
444
- data_len = len (data )
445
- batch_len = data_len // batch_size
446
- epoch_size = (batch_len - 1 ) // num_steps
447
- return epoch_size >= 1
448
-
449
- valid_data_valid = is_valid_data (valid_data , config .batch_size ,
450
- config .num_steps )
451
- if valid_data_valid :
452
- valid_ppl = eval (valid_data )
453
- print ("Valid ppl: %.5f" % valid_ppl [0 ])
454
- else :
455
- print (
456
- 'WARNING: length of valid_data is {}, which is not enough for batch_size {} and num_steps {}' .
457
- format (
458
- len (valid_data ), config .batch_size , config .num_steps ))
459
-
460
- save_model_dir = os .path .join (args .save_model_dir , str (epoch_id ))
461
- if not os .path .exists (save_model_dir ):
462
- mkpath (save_model_dir )
463
- save_model_dir = os .path .join (save_model_dir , 'params' )
464
-
465
- fluid .save (main_program , save_model_dir )
466
- print ("Saved model to: %s.\n " % save_model_dir )
447
+ if not args .profile :
448
+ # NOTE(zjl): sometimes we have not enough data for eval if batch_size is large, i.e., 2100
449
+ # Just skip to avoid error
450
+ def is_valid_data (data , batch_size , num_steps ):
451
+ data_len = len (data )
452
+ batch_len = data_len // batch_size
453
+ epoch_size = (batch_len - 1 ) // num_steps
454
+ return epoch_size >= 1
455
+
456
+ valid_data_valid = is_valid_data (valid_data , config .batch_size ,
457
+ config .num_steps )
458
+ if valid_data_valid :
459
+ valid_ppl = eval (valid_data )
460
+ print ("Valid ppl: %.5f" % valid_ppl [0 ])
461
+ else :
462
+ print (
463
+ 'WARNING: length of valid_data is {}, which is not enough for batch_size {} and num_steps {}' .
464
+ format (
465
+ len (valid_data ), config .batch_size ,
466
+ config .num_steps ))
467
+
468
+ save_model_dir = os .path .join (args .save_model_dir ,
469
+ str (epoch_id ))
470
+ if not os .path .exists (save_model_dir ):
471
+ mkpath (save_model_dir )
472
+ save_model_dir = os .path .join (save_model_dir , 'params' )
473
+
474
+ fluid .save (main_program , save_model_dir )
475
+ print ("Saved model to: %s.\n " % save_model_dir )
467
476
468
477
with profile_context (args .profile , args .profiler_path ):
469
478
train ()
0 commit comments