 from model_training.custom_datasets.ranking_collator import RankingDataCollator
 from model_training.metrics import RewardMetrics
 from torch.utils.data import DataLoader
+from tqdm import tqdm
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from transformers.trainer_utils import EvalPrediction
 from utils import write_to_json
@@ -29,15 +30,16 @@ def get_ranking_dataset(dataset, split):
 def batch_inference(inputs, model):
     batch, cu_lens = inputs
     batch = {k: v.to(model.device) for k, v in batch.items()}
-    logits = (
-        model(
-            input_ids=batch["input_ids"],
-            attention_mask=batch["attention_mask"],
-        )
-        .logits.detach()
-        .cpu()
-        .numpy()
-    )
+
+    with torch.no_grad():
+        logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().cpu()
+
+    if logits.dtype == torch.bfloat16:
+        # As of NumPy 1.21.4, NumPy does not support bfloat16 (see
+        # https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst).
+        # Until NumPy adds bfloat16, we must convert to float32.
+        logits = logits.to(torch.float32)
+    logits = logits.numpy()
 
     labels = []
     for i, (s, e) in enumerate(zip(cu_lens[:-1], cu_lens[1:])):
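Aside: the bfloat16 branch above is needed because torch.Tensor.numpy() raises a TypeError for bfloat16 tensors, since NumPy has no matching scalar type. A minimal standalone sketch of the failure and the workaround:

    import torch

    x = torch.zeros(2, dtype=torch.bfloat16)
    # x.numpy()  # raises TypeError: NumPy has no bfloat16 scalar type
    y = x.to(torch.float32).numpy()  # upcast first, then convert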
@@ -54,6 +56,7 @@ def batch_inference(inputs, model):
     parser.add_argument("--metrics", type=str, help="metrics to evaluate", default="accuracy")
     parser.add_argument("--batch_size", type=int, help="Batch Size", default=8)
     parser.add_argument("--device", type=str, help="device", default="cuda")
+    parser.add_argument("--dtype", type=str, help="data type", default=None)
     args = parser.parse_args().__dict__
 
     if args.get("device") != "cpu":
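Note that args is a plain dict here (parser.parse_args().__dict__), so the new flag is read back with args.get("dtype"). Leaving --dtype at its default of None is meaningful: the loading code in the next hunk maps a falsy value to torch_dtype="auto", roughly:

    dtype = args.get("dtype")  # None, "float16", "bfloat16", ...
    torch_dtype = "auto" if not dtype else dtype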
@@ -64,7 +67,9 @@ def batch_inference(inputs, model):
     model_name = args.get("model")
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForSequenceClassification.from_pretrained(model_name)
+    model = AutoModelForSequenceClassification.from_pretrained(
+        model_name, torch_dtype="auto" if not args.get("dtype") else args.get("dtype")
+    )
     model.eval()
     model.to(device)
     max_length = args.get("max_length") or model.config.max_position_embeddings
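With torch_dtype="auto", from_pretrained derives the dtype from the checkpoint itself (its config.torch_dtype, falling back to the saved weights' dtype) rather than upcasting everything to float32. A sketch of both paths, with a hypothetical model name:

    import torch
    from transformers import AutoModelForSequenceClassification

    # use the checkpoint's native dtype
    m = AutoModelForSequenceClassification.from_pretrained("some-reward-model", torch_dtype="auto")
    # explicit override, matching e.g. --dtype float16
    m = AutoModelForSequenceClassification.from_pretrained("some-reward-model", torch_dtype=torch.float16)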
@@ -77,7 +82,7 @@ def batch_inference(inputs, model):
     metrics = args.get("metrics").split(",")
     compute_metrics = RewardMetrics(metrics)
     score_dict = defaultdict(float)
-    for i, data in enumerate(dataset):
+    for i, data in enumerate(tqdm(dataset)):
         eval_pred = batch_inference(data, model)
         results = compute_metrics(eval_pred)
         for metric in metrics:
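A hypothetical invocation exercising the new flag (script name and model path are placeholders, not taken from this diff):

    python eval_rm.py --model <reward-model> --metrics accuracy --batch_size 8 --dtype bfloat16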