Commit a979002
Add quantization/8bit model loading support for sampling_report.py (LAION-AI#2857)

Add --quantize to the script call for the option to take effect.

Tested using:
1. --quantize --model-name OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5
2. --quantize --model-name t5-small --model-type t5conditional
3. --quantize --model-name OpenAssistant/stablelm-7b-sft-v7-epoch-3

Unable to test on llama models without access to the base weights (and/or with only 35 GB of VRAM?).

Enjoy, TP
1 parent 322cf35 commit a979002
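
For reference, a possible invocation of the updated script might look like the following. Note that the merged diff names the flag --int8, while the commit message above calls it --quantize; the --prompts path here is only illustrative:

    python model/model_eval/manual/sampling_report.py --int8 \
        --model-name OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5 \
        --prompts prompts.jsonl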

1 file changed (+17, -3 lines)

model/model_eval/manual/sampling_report.py (+17, -3)

@@ -226,6 +226,7 @@ def parse_args():
     parser.add_argument("--num-samples", type=int, default=2, help="number of sampling runs per configuration")
     parser.add_argument("--config", type=str, default="config/default.json", help="configuration file path")
     parser.add_argument("--half", action="store_true", default=False, help="use float16")
+    parser.add_argument("--int8", action="store_true", default=False, help="use int8 quantization")
     parser.add_argument("--skip-special-tokens", action="store_true", default=False)
     parser.add_argument("--model-type", type=str, default="CausalLM", help="CausalLM, T5Conditional, LLaMA")
     parser.add_argument("--max-input-len", type=int, help="max token counts for input")
@@ -247,6 +248,10 @@ def main():
     print("Using pytorch version {}".format(torch.__version__))
 
     args = parse_args()
+    if args.int8 and not torch.cuda.is_available():
+        print("Warning: --int8 argument passed but cuda is not available. Ignoring --int8.")
+        args.int8 = False
+
     print("Args:", args)
 
     torch.set_num_threads(args.num_threads)
@@ -265,17 +270,23 @@ def main():
     model_name = args.model_name
     print(f"Loading model: {model_name}")
 
+    model_args = {}
+    if args.int8:
+        # these will break model.to(device) later in the script so a conditional check is needed
+        model_args["load_in_8bit"] = args.int8
+        model_args["device_map"] = "auto"
+
     if args.model_type.lower() == "causallm" or args.model_type.lower() == "llama":
         from transformers import AutoModelForCausalLM
 
         tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=args.auth_token)
-        model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=args.auth_token)
+        model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=args.auth_token, **model_args)
         skip_input_tokens = True
     elif args.model_type.lower() == "t5conditional":
         from transformers import T5ForConditionalGeneration
 
         tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=args.auth_token)
-        model = T5ForConditionalGeneration.from_pretrained(model_name, use_auth_token=args.auth_token)
+        model = T5ForConditionalGeneration.from_pretrained(model_name, use_auth_token=args.auth_token, **model_args)
         skip_input_tokens = False
     else:
         raise RuntimeError("Invalid model_type specified")
@@ -293,7 +304,10 @@ def main():
     model.eval()
     if args.half:
         model = model.half()
-    model = model.to(device)
+
+    # int8 models (load_in_8bit = True + device_map = auto): will cause this method to error
+    if not args.int8:
+        model = model.to(device)
 
     print(f"Loading prompts file: {args.prompts}")
     prompts = load_jsonl(input_file_path=args.prompts)
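
Outside the context of sampling_report.py, a minimal sketch of the loading pattern this commit adopts, assuming transformers with accelerate and bitsandbytes installed and using a placeholder model name and prompt:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# placeholder model; any causal LM on the Hugging Face Hub works the same way
model_name = "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"
use_int8 = torch.cuda.is_available()  # bitsandbytes int8 quantization needs a CUDA device

model_args = {}
if use_int8:
    # load_in_8bit quantizes the weights via bitsandbytes; device_map="auto" lets
    # accelerate place the layers, so model.to(device) must NOT be called afterwards
    model_args["load_in_8bit"] = True
    model_args["device_map"] = "auto"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, **model_args)
model.eval()

if not use_int8:
    # only non-quantized models are moved manually; 8-bit models are already placed
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")

inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))

The point the diff encodes is that load_in_8bit=True together with device_map="auto" already places the quantized weights on the GPU, so the usual model.to(device) call has to be skipped for those models.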
