
Commit de2a50f

Add script to export model to HF or local directory (LAION-AI#2028)
```
usage: export_model.py [-h] [--dtype DTYPE] [--hf_repo_name HF_REPO_NAME] [--auth_token AUTH_TOKEN]
                       [--output_folder OUTPUT_FOLDER] [--max_shard_size MAX_SHARD_SIZE] [--cache_dir CACHE_DIR]
                       model_name

positional arguments:
  model_name            checkpoint path or model name

options:
  -h, --help            show this help message and exit
  --dtype DTYPE         float16 or float32
  --hf_repo_name HF_REPO_NAME
                        Huggingface repository name
  --auth_token AUTH_TOKEN
                        User access token
  --output_folder OUTPUT_FOLDER
                        output folder path
  --max_shard_size MAX_SHARD_SIZE
  --cache_dir CACHE_DIR
```
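
For example, the script can export a fine-tuned checkpoint to a local folder, or push it straight to the Hugging Face Hub. The checkpoint path, output folder, repository name, and token below are hypothetical placeholders, not values from this commit:

```
python export_model.py ./checkpoints/checkpoint-1000 --dtype float16 --output_folder ./exported-model

python export_model.py ./checkpoints/checkpoint-1000 --hf_repo_name your-username/your-model --auth_token hf_xxx
```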
1 parent d8c9eaa commit de2a50f

File tree: 1 file changed (+62 −0)

export_model.py: 62 additions & 0 deletions (new file)
```python
import argparse
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("model_name", type=str, help="checkpoint path or model name")
    parser.add_argument("--dtype", type=str, default="float16", help="float16 or float32")
    parser.add_argument("--hf_repo_name", type=str, help="Huggingface repository name")
    parser.add_argument("--auth_token", type=str, help="User access token")
    parser.add_argument("--output_folder", type=str, help="output folder path")
    parser.add_argument("--max_shard_size", type=str, default="10GB")
    parser.add_argument("--cache_dir", type=str)
    return parser.parse_args()


def main():
    args = parse_args()

    # Map the requested dtype string to the corresponding torch dtype.
    if args.dtype in ("float16", "fp16"):
        torch_dtype = torch.float16
    elif args.dtype in ("float32", "fp32"):
        torch_dtype = torch.float32
    else:
        print(f"Unsupported dtype: {args.dtype}")
        sys.exit(1)

    # At least one export target (HF repo or local folder) must be given.
    if not args.hf_repo_name and not args.output_folder:
        print(
            "Please specify either `--hf_repo_name` to push to HF or `--output_folder` "
            "to export the model to a local folder."
        )
        sys.exit(1)

    print(f"Loading tokenizer '{args.model_name}' ...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    print(f"{type(tokenizer).__name__} (vocab_size={len(tokenizer)})")

    print(f"Loading model '{args.model_name}' ({args.dtype}) ...")
    model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch_dtype, cache_dir=args.cache_dir)
    print(f"{type(model).__name__} (num_parameters={model.num_parameters()})")

    # Export the model and tokenizer to a local directory, sharding the
    # weights at --max_shard_size.
    if args.output_folder:
        print(f"Saving model to: {args.output_folder}")
        model.save_pretrained(args.output_folder, max_shard_size=args.max_shard_size)

        print(f"Saving tokenizer to: {args.output_folder}")
        tokenizer.save_pretrained(args.output_folder)

    # Push the model and tokenizer to the Hugging Face Hub.
    if args.hf_repo_name:
        print("Uploading model to HF...")
        model.push_to_hub(args.hf_repo_name, use_auth_token=args.auth_token, max_shard_size=args.max_shard_size)

        print("Uploading tokenizer to HF...")
        tokenizer.push_to_hub(args.hf_repo_name, use_auth_token=args.auth_token)


if __name__ == "__main__":
    main()
```
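
A quick way to sanity-check an export is to load the artifacts back from the output folder and run a short generation. A minimal sketch, assuming the model was exported to a hypothetical `./exported-model` folder as in the usage example above:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the exported tokenizer and model back from the local folder
# (the path is a hypothetical placeholder).
tokenizer = AutoTokenizer.from_pretrained("./exported-model")
model = AutoModelForCausalLM.from_pretrained("./exported-model", torch_dtype=torch.float16)

# Generate a few tokens to confirm the weights and tokenizer are consistent.
inputs = tokenizer("Hello, world!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```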
