@@ -226,6 +226,7 @@ def parse_args():
     parser.add_argument("--num-samples", type=int, default=2, help="number of sampling runs per configuration")
     parser.add_argument("--config", type=str, default="config/default.json", help="configuration file path")
     parser.add_argument("--half", action="store_true", default=False, help="use float16")
+    parser.add_argument("--int8", action="store_true", default=False, help="use int8 quantization")
     parser.add_argument("--skip-special-tokens", action="store_true", default=False)
     parser.add_argument("--model-type", type=str, default="CausalLM", help="CausalLM, T5Conditional, LLaMA")
     parser.add_argument("--max-input-len", type=int, help="max token counts for input")
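A minimal sketch of how the new flag behaves, assuming nothing beyond argparse (only the `--int8` definition is taken from the diff; the bare-bones parser around it is illustrative):

```python
# store_true flags default to False and flip to True only when passed,
# so --int8 is strictly opt-in.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--int8", action="store_true", default=False, help="use int8 quantization")

print(parser.parse_args([]).int8)          # False
print(parser.parse_args(["--int8"]).int8)  # True
```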
@@ -247,6 +248,10 @@ def main():
     print("Using pytorch version {}".format(torch.__version__))
 
     args = parse_args()
+    if args.int8 and not torch.cuda.is_available():
+        print("Warning: --int8 argument passed but cuda is not available. Ignoring --int8.")
+        args.int8 = False
+
     print("Args:", args)
 
     torch.set_num_threads(args.num_threads)
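The guard downgrades gracefully rather than failing later: 8-bit quantization via bitsandbytes needs a CUDA device, so requesting it on a CPU-only machine would otherwise surface as an opaque error during model loading. A standalone sketch of the pattern (the `Args` class is an illustrative stand-in for the parsed argparse namespace):

```python
import torch

class Args:  # stand-in for the argparse namespace
    int8 = True

args = Args()
if args.int8 and not torch.cuda.is_available():
    print("Warning: --int8 argument passed but cuda is not available. Ignoring --int8.")
    args.int8 = False
```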
@@ -265,17 +270,23 @@ def main():
     model_name = args.model_name
     print(f"Loading model: {model_name}")
 
+    model_args = {}
+    if args.int8:
+        # these will break model.to(device) later in the script so a conditional check is needed
+        model_args["load_in_8bit"] = args.int8
+        model_args["device_map"] = "auto"
+
     if args.model_type.lower() == "causallm" or args.model_type.lower() == "llama":
         from transformers import AutoModelForCausalLM
 
         tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=args.auth_token)
-        model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=args.auth_token)
+        model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=args.auth_token, **model_args)
         skip_input_tokens = True
     elif args.model_type.lower() == "t5conditional":
         from transformers import T5ForConditionalGeneration
 
         tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=args.auth_token)
-        model = T5ForConditionalGeneration.from_pretrained(model_name, use_auth_token=args.auth_token)
+        model = T5ForConditionalGeneration.from_pretrained(model_name, use_auth_token=args.auth_token, **model_args)
         skip_input_tokens = False
     else:
         raise RuntimeError("Invalid model_type specified")
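A self-contained sketch of the 8-bit loading path these kwargs enable (the model id is a placeholder; `load_in_8bit=True` requires the bitsandbytes package, accelerate, and a CUDA GPU):

```python
# Minimal sketch, assuming bitsandbytes and accelerate are installed.
# With load_in_8bit=True, from_pretrained quantizes the linear layers to int8;
# device_map="auto" lets accelerate place the weights, so no .to(device) follows.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,
    device_map="auto",
)
print(model.device)  # weights already live on the device chosen by accelerate
```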
@@ -293,7 +304,10 @@ def main():
     model.eval()
     if args.half:
         model = model.half()
-    model = model.to(device)
+
+    # int8 models (load_in_8bit=True + device_map="auto") will cause this method to error
+    if not args.int8:
+        model = model.to(device)
 
     print(f"Loading prompts file: {args.prompts}")
     prompts = load_jsonl(input_file_path=args.prompts)
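The conditional move is the flip side of `device_map="auto"`: an 8-bit model is already placed on GPU at load time, and transformers raises an error when `.to()` is called on a quantized model, so only the fp32/fp16 paths still need explicit placement. A hedged sketch of the resulting placement and generation logic, reusing `args`, `model`, and `tokenizer` from the sketches above:

```python
import torch

# Quantized models skip .to(device) entirely; moving them is rejected.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not args.int8:
    model = model.to(device)  # safe for fp32/fp16 weights only

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```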