@@ -93,11 +93,11 @@ class DataTrainingArguments:
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
-    max_seq_length: Optional[int] = field(
-        default=None,
+    max_seq_length: int = field(
+        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
-            "than this will be truncated. Default to the max input length of the model. "
+            "than this will be truncated."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
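On the CLI side, the `default` and the `metadata["help"]` string of the dataclass field above are what the argument parser exposes, so `--max_seq_length` now defaults to 512 instead of deferring to the model. A minimal sketch of that round trip, assuming the script wires `DataTrainingArguments` through `transformers.HfArgumentParser` as the example scripts usually do (the shortened help string here is illustrative):

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    max_seq_length: int = field(
        default=512,
        metadata={"help": "The maximum total input sequence length after tokenization."},
    )

parser = HfArgumentParser(DataTrainingArguments)
# An explicit flag wins; omitting --max_seq_length yields the new default of 512.
(data_args,) = parser.parse_args_into_dataclasses(args=["--max_seq_length", "256"])
assert data_args.max_seq_length == 256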
@@ -286,15 +286,12 @@ def tokenize_function(examples):
        load_from_cache_file=not data_args.overwrite_cache,
    )

-    if data_args.max_seq_length is None:
-        max_seq_length = tokenizer.model_max_length
-    else:
-        if data_args.max_seq_length > tokenizer.model_max_length:
-            logger.warn(
-                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
-                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
-            )
-        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+    if data_args.max_seq_length > tokenizer.model_max_length:
+        logger.warn(
+            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+        )
+    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of
    # max_seq_length.
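Taken together, the two hunks drop the `None` branch that fell back to `tokenizer.model_max_length` and instead always clamp the (now always-set) value. A standalone sketch of the resulting behavior, with an illustrative function name (`resolve_max_seq_length` is not part of the script):

import logging

logger = logging.getLogger(__name__)

def resolve_max_seq_length(requested: int, model_max_length: int) -> int:
    # Cap the requested length at the model's maximum, warning when it overflows.
    if requested > model_max_length:
        logger.warning(
            "The max_seq_length passed (%d) is larger than the maximum length "
            "for the model (%d). Using max_seq_length=%d.",
            requested, model_max_length, model_max_length,
        )
    return min(requested, model_max_length)

# A BERT-style tokenizer reports model_max_length=512, so 1024 is clamped
# down with a warning while the new default of 512 passes through unchanged.
assert resolve_max_seq_length(1024, 512) == 512
assert resolve_max_seq_length(512, 512) == 512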