-
Notifications
You must be signed in to change notification settings - Fork 271
/
Copy pathllm_optimize_smoothquant.py
240 lines (218 loc) · 7.59 KB
/
llm_optimize_smoothquant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import torch
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
###################################################### # noqa F401
import argparse
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
)
# args
# Command-line interface for the script. The automatic -h/--help flag is
# switched off via add_help=False. Options are registered in a fixed order:
# generation options first, then calibration / quantization options.
parser = argparse.ArgumentParser(
    "Generation script (static quantization path)",
    add_help=False,
)

# --- generation options ---
parser.add_argument(
    "--dtype",
    default="float32",
    choices=["float32", "bfloat16"],
    type=str,
    help="choose the weight dtype and whether to enable auto mixed precision or not",
)
parser.add_argument(
    "--max-new-tokens",
    type=int,
    default=32,
    help="output max new tokens",
)
parser.add_argument(
    "--prompt",
    type=str,
    default="What are we having for dinner?",
    help="input prompt",
)
parser.add_argument("--greedy", action="store_true")
parser.add_argument("--batch-size", type=int, default=1, help="batch size")

# --- calibration / quantization options ---
parser.add_argument("--calibration", action="store_true")
parser.add_argument(
    "--calibration-samples",
    type=int,
    default=512,
    help="total number of calibration samples",
)
parser.add_argument(
    "--int8-qconfig",
    default="./qconfig.json",
    nargs="?",
    help="static quantization factors summary files generated by calibration",
)
parser.add_argument("--dataset", default="NeelNanda/pile-10k", nargs="?")
parser.add_argument(
    "--alpha",
    type=float,
    default=0.5,
    help="alpha value for smoothquant",
)

# Parse the process arguments and echo them so a run log records its settings.
args = parser.parse_args()
print(args)
# dtype
# When --calibration is set, the model is loaded in float32 and amp_enabled is
# False regardless of --dtype; otherwise --dtype selects the torch dtype and
# amp_enabled is True for any non-float32 choice.
# The flag lives on the parsed namespace: there is no bare `calibration` name.
amp_enabled = args.dtype != "float32" and not args.calibration
amp_dtype = torch.float32 if args.calibration else getattr(torch, args.dtype)
# load model
# Fetch config, weights and tokenizer for a fixed checkpoint id. The config is
# created with torchscript=True and handed to the model loader.
model_id = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(
    model_id,
    torchscript=True,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    torch_dtype=amp_dtype,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# Switch to inference mode and request the channels_last memory format; both
# calls return the module, so they are chained into one rebinding.
model = model.eval().to(memory_format=torch.channels_last)
num_beams = 1 if args.greedy else 4
# Beam-index buffer, shared (same tensor object) by every layer entry below.
# NOTE(review): 2048 is presumably the maximum supported sequence length --
# confirm against the IPEX LLM documentation.
beam_idx_tmp = torch.zeros(
    (2048, int(args.batch_size * num_beams)), dtype=torch.long
).contiguous()
# Width of one attention head, computed once instead of once per tensor.
head_dim = model.config.hidden_size // model.config.num_attention_heads
# Placeholder past_key_values handed to the model during calibration: one
# 4-tuple per hidden layer made of an empty long tensor, two zero tensors of
# shape [1, num_attention_heads, 1, head_dim], and the shared beam-index
# buffer. Both zero tensors are sized from `model.config`; the script defines
# no other model object.
# NOTE(review): presumably this is the cache layout IPEX's optimized attention
# expects -- confirm against the IPEX LLM documentation.
global_past_key_value = [
    (
        torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
        torch.zeros(
            [1, model.config.num_attention_heads, 1, head_dim]
        ).contiguous(),
        torch.zeros(
            [1, model.config.num_attention_heads, 1, head_dim]
        ).contiguous(),
        beam_idx_tmp,
    )
    for _ in range(model.config.num_hidden_layers)
]
# Intel(R) Extension for PyTorch*
#################### code changes #################### # noqa F401
class Calibration:
    """Tokenize a dataset and collate its samples into calibration inputs.

    The constructor maps ``tokenize_function`` over ``dataset`` (batched) and
    restricts the dataset's output format to torch ``input_ids``.
    ``collate_batch`` is meant to be passed as a DataLoader ``collate_fn``.

    Args:
        dataset: object exposing ``map`` and ``set_format`` (presumably a
            Hugging Face ``datasets.Dataset`` -- confirm against the caller).
        tokenizer: callable applied to a batch of raw texts.
        batch_size: stored on the instance; not read by any method here.
        pad_val: stored on the instance; not read by any method here.
        pad_max: maximum number of tokens kept per sample in ``collate_batch``.
    """

    # Columns probed, in priority order, for the raw text of each example.
    _TEXT_COLUMNS = ("prompt", "text", "code")

    def __init__(self, dataset, tokenizer, batch_size=1, pad_val=1, pad_max=512):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.pad_val = pad_val
        self.pad_max = pad_max
        # tokenize the dataset
        self.dataset = self.dataset.map(self.tokenize_function, batched=True)
        self.dataset.set_format(type="torch", columns=["input_ids"])

    @torch.no_grad()
    def tokenize_function(self, examples):
        """Tokenize the first available text column of ``examples``.

        Args:
            examples: mapping from column name to a batch of values.

        Returns:
            Whatever ``self.tokenizer`` returns for the selected column.

        Raises:
            ValueError: if none of ``prompt``, ``text`` or ``code`` is present.
        """
        for column in self._TEXT_COLUMNS:
            if column in examples:
                return self.tokenizer(examples[column])
        # Fail with an actionable message instead of an unbound-variable error.
        raise ValueError(
            "dataset examples must contain one of the columns "
            f"{self._TEXT_COLUMNS}; got {list(examples)}"
        )

    @torch.no_grad()
    def collate_batch(self, batch):
        """Build model inputs from a list of tokenized samples.

        Each sample's ``input_ids`` is truncated to ``pad_max`` tokens. No
        padding is applied, so every sample in ``batch`` must have the same
        length after truncation for the stacking to succeed (always true for
        a batch of one).

        Args:
            batch: iterable of mappings, each with a 1-D ``input_ids`` tensor.

        Returns:
            ``((input_ids, attention_mask, position_ids, past_key_values),
            last_ind)`` where the first three are stacked 2-D tensors,
            ``past_key_values`` is ``tuple(global_past_key_value)`` and
            ``last_ind`` holds the index of the last token of each sample.
        """
        max_len = int(self.pad_max)
        input_ids_padded = []
        attention_mask_padded = []
        position_ids_padded = []
        last_ind = []
        for sample in batch:
            # Slicing leaves sequences shorter than max_len untouched.
            input_ids = sample["input_ids"][:max_len]
            seq_len = input_ids.shape[0]
            last_ind.append(seq_len - 1)
            input_ids_padded.append(input_ids)
            attention_mask_padded.append(torch.ones(seq_len))
            position_ids_padded.append(torch.arange(seq_len))
        return (
            (
                torch.vstack(input_ids_padded),
                torch.vstack(attention_mask_padded),
                torch.vstack(position_ids_padded),
                # Module-level placeholder KV cache defined earlier in the script.
                tuple(global_past_key_value),
            ),
            torch.tensor(last_ind),
        )
# Calibration data pipeline: load the dataset named by --dataset, tokenize it
# through Calibration, and wrap it in a DataLoader.
# These two names are needed only here, so they are imported next to their
# single use.
from datasets import load_dataset
from torch.utils.data import DataLoader

# NOTE(review): this runs even without --calibration, so the dataset is loaded
# and tokenized on the quantize-and-generate path too -- confirm whether that
# is intended.
calib_dataset = load_dataset(args.dataset, split="train")
calib_evaluator = Calibration(calib_dataset, tokenizer, args.batch_size)
# One sample per step: Calibration.collate_batch stacks sequences without
# padding, which only works when all samples in a batch share one length.
calib_dataloader = DataLoader(
    calib_evaluator.dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=calib_evaluator.collate_batch,
)
# SmoothQuant qconfig mapping, parameterized by --alpha; used by both paths.
qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=args.alpha)

if args.calibration:
    # Calibration path: run the prepared model over calibration samples, write
    # the qconfig summary to --int8-qconfig, then exit.
    from intel_extension_for_pytorch.quantization import prepare

    # The first collated batch serves as example inputs; None if the loader
    # yields nothing.
    first_batch = next(iter(calib_dataloader), None)
    example_inputs = None if first_batch is None else tuple(first_batch[0])

    model = ipex.llm.optimize(
        model.eval(),
        dtype=amp_dtype,
        quantization_config=qconfig,
        inplace=True,
        deployment_mode=False,
    )
    prepared_model = prepare(model.eval(), qconfig, example_inputs=example_inputs)

    with torch.no_grad():
        for step, (model_inputs, _last_ind) in enumerate(calib_dataloader):
            # Stop once --calibration-samples batches have been processed.
            if step == args.calibration_samples:
                break
            ids, mask, positions, kv_cache = model_inputs
            prepared_model(
                ids,
                attention_mask=mask,
                position_ids=positions,
                past_key_values=kv_cache,
            )

    prepared_model.save_qconf_summary(qconf_summary=args.int8_qconfig)
    print(
        "calibration Done! Will exit and please launch model quantization and benchmark"
    )
    exit(0)
else:
    # Quantization path: apply the previously saved qconfig summary and build
    # the deployment-mode model used for generation below.
    model = ipex.llm.optimize(
        model.eval(),
        dtype=amp_dtype,
        quantization_config=qconfig,
        qconfig_summary_file=args.int8_qconfig,
        inplace=True,
        deployment_mode=True,
    )
    print("model quantization - Done!")
###################################################### # noqa F401
# generate args
# Beam width is 1 with --greedy and 4 otherwise; sampling is disabled.
num_beams = 1 if args.greedy else 4
generate_kwargs = {"do_sample": False, "temperature": 0.9, "num_beams": num_beams}

# input prompt
# Report the token count of a single prompt, then replicate the prompt to
# fill the requested batch.
prompt = args.prompt
input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
print("---- Prompt size:", input_size)
prompt = [prompt] * args.batch_size

# inference
# Generation runs without gradient tracking; CPU autocast is active only when
# amp_enabled is True.
with torch.no_grad(), torch.inference_mode(), torch.cpu.amp.autocast(
    enabled=amp_enabled
):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    gen_ids = model.generate(
        input_ids,
        max_new_tokens=args.max_new_tokens,
        **generate_kwargs,
    )
    gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
    # Per-row token counts before and after generation; their difference is
    # the number of newly generated tokens for each row.
    input_tokens_lengths = [row.shape[0] for row in input_ids]
    output_tokens_lengths = [row.shape[0] for row in gen_ids]
    total_new_tokens = [
        out_len - in_len
        for in_len, out_len in zip(input_tokens_lengths, output_tokens_lengths)
    ]
    print(gen_text, total_new_tokens, flush=True)