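"""llm_optimize_woq.py

Text-generation example for the weight-only quantization (WOQ) path of
Intel(R) Extension for PyTorch*. The script loads facebook/opt-125m, applies
WOQ through ipex.llm.optimize, and runs generation on a single prompt.
"""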
import torch
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
###################################################### # noqa F401
import argparse
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)
# args
parser = argparse.ArgumentParser(
    "Generation script (weight only quantization path)", add_help=False
)
parser.add_argument(
    "--dtype",
    type=str,
    choices=["float32", "bfloat16"],
    default="float32",
    help="choose the weight dtype and whether to enable auto mixed precision or not",
)
parser.add_argument(
    "--max-new-tokens", default=32, type=int, help="output max new tokens"
)
parser.add_argument(
    "--prompt", default="What are we having for dinner?", type=str, help="input prompt"
)
parser.add_argument("--greedy", action="store_true")
parser.add_argument("--batch-size", default=1, type=int, help="batch size")
# Intel(R) Extension for PyTorch*
#################### code changes #################### # noqa F401
parser.add_argument(
    "--lowp-mode",
    choices=["AUTO", "BF16", "FP32", "INT8", "FP16"],
    default="AUTO",
    type=str,
    help="low precision mode for weight only quantization. "
    "It indicates the data type used for computation to speed up inference at the "
    "cost of accuracy. It is unrelated to the activation or weight data type. "
    "Using lowp_mode=INT8 with INT8 weight is not supported yet and falls back to "
    "lowp_mode=BF16 implicitly in this case. "
    "If set to AUTO, lowp_mode is determined by the weight data type: "
    "lowp_mode=BF16 is used for INT8 weight "
    "and lowp_mode=INT8 is used for INT4 weight.",
)
parser.add_argument(
    "--weight-dtype",
    choices=["INT8", "INT4"],
    default="INT8",
    type=str,
    help="weight data type for weight only quantization. Unrelated to activation"
    " data type or lowp-mode. If `--low-precision-checkpoint` is given, weight"
    " data type is always INT4 and this argument is not needed.",
)
parser.add_argument(
    "--low-precision-checkpoint",
    default="",
    type=str,
    help="Low precision checkpoint file generated by calibration, such as GPTQ. It contains"
    " modified weights, scales, zero points, etc. It enables better accuracy for weight"
    " only quantization with INT4 weight.",
)
###################################################### # noqa F401
args = parser.parse_args()
print(args)
# dtype
amp_enabled = args.dtype != "float32"
amp_dtype = getattr(torch, args.dtype)
# load model
model_id = "facebook/opt-125m"
config = AutoConfig.from_pretrained(model_id, torchscript=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=amp_dtype,
    config=config,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = model.eval()
model = model.to(memory_format=torch.channels_last)
# Intel(R) Extension for PyTorch*
#################### code changes #################### # noqa F401
from intel_extension_for_pytorch.quantization import WoqWeightDtype
weight_dtype = (
    WoqWeightDtype.INT4 if args.weight_dtype == "INT4" else WoqWeightDtype.INT8
)
# map the --lowp-mode argument to the corresponding WoqLowpMode enum value
if args.lowp_mode == "INT8":
    lowp_mode = ipex.quantization.WoqLowpMode.INT8
elif args.lowp_mode == "FP32":
    lowp_mode = ipex.quantization.WoqLowpMode.NONE
elif args.lowp_mode == "FP16":
    lowp_mode = ipex.quantization.WoqLowpMode.FP16
elif args.lowp_mode == "BF16":
    lowp_mode = ipex.quantization.WoqLowpMode.BF16
else:  # AUTO
    if args.low_precision_checkpoint != "" or weight_dtype == WoqWeightDtype.INT4:
        lowp_mode = ipex.quantization.WoqLowpMode.INT8
    else:
        lowp_mode = ipex.quantization.WoqLowpMode.BF16
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
    weight_dtype=weight_dtype, lowp_mode=lowp_mode
)
if args.low_precision_checkpoint != "":
    low_precision_checkpoint = torch.load(args.low_precision_checkpoint)
else:
    low_precision_checkpoint = None
model = ipex.llm.optimize(
    model.eval(),
    dtype=amp_dtype,
    quantization_config=qconfig,
    low_precision_checkpoint=low_precision_checkpoint,
    deployment_mode=True,
    inplace=True,
)
###################################################### # noqa F401
# generate args
num_beams = 1 if args.greedy else 4
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=num_beams)
# input prompt
prompt = args.prompt
input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
print("---- Prompt size:", input_size)
prompt = [prompt] * args.batch_size
# inference
with torch.no_grad(), torch.inference_mode(), torch.cpu.amp.autocast(
    enabled=amp_enabled
):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    gen_ids = model.generate(
        input_ids, max_new_tokens=args.max_new_tokens, **generate_kwargs
    )
    gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
    input_tokens_lengths = [x.shape[0] for x in input_ids]
    output_tokens_lengths = [x.shape[0] for x in gen_ids]
    total_new_tokens = [
        o - i for i, o in zip(input_tokens_lengths, output_tokens_lengths)
    ]
    print(gen_text, total_new_tokens, flush=True)
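# Example invocations (a sketch; the flags are the arguments defined above, and the
# GPTQ checkpoint path is a hypothetical placeholder, not a file shipped with this script):
#   python llm_optimize_woq.py --dtype bfloat16 --weight-dtype INT4 --max-new-tokens 64
#   python llm_optimize_woq.py --low-precision-checkpoint gptq_checkpoint.pt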