@@ -76,10 +76,16 @@ def parse_args():
76
76
"--validation_file" , type = str , default = None , help = "A csv or a json file containing the validation data."
77
77
)
78
78
parser .add_argument (
79
- "--text_column_name" , type = str , default = None , help = "The column name of text to input in the file (a csv or JSON file)."
79
+ "--text_column_name" ,
80
+ type = str ,
81
+ default = None ,
82
+ help = "The column name of text to input in the file (a csv or JSON file)." ,
80
83
)
81
84
parser .add_argument (
82
- "--label_column_name" , type = str , default = None , help = "The column name of label to input in the file (a csv or JSON file)."
85
+ "--label_column_name" ,
86
+ type = str ,
87
+ default = None ,
88
+ help = "The column name of label to input in the file (a csv or JSON file)." ,
83
89
)
84
90
parser .add_argument (
85
91
"--max_length" ,
@@ -266,17 +272,17 @@ def main():
266
272
column_names = raw_datasets ["validation" ].column_names
267
273
features = raw_datasets ["validation" ].features
268
274
269
- if data_args .text_column_name is not None :
270
- text_column_name = data_args .text_column_name
275
+ if args .text_column_name is not None :
276
+ text_column_name = args .text_column_name
271
277
elif "tokens" in column_names :
272
278
text_column_name = "tokens"
273
279
else :
274
280
text_column_name = column_names [0 ]
275
281
276
- if data_args .label_column_name is not None :
277
- label_column_name = data_args .label_column_name
278
- elif f"{ data_args .task_name } _tags" in column_names :
279
- label_column_name = f"{ data_args .task_name } _tags"
282
+ if args .label_column_name is not None :
283
+ label_column_name = args .label_column_name
284
+ elif f"{ args .task_name } _tags" in column_names :
285
+ label_column_name = f"{ args .task_name } _tags"
280
286
else :
281
287
label_column_name = column_names [1 ]
282
288
0 commit comments