dynamic_quantization_bert_tutorial.rst


(Beta) Dynamic Quantization on BERT

Tip

To get the most of this tutorial, we recommend using this Colab version, which will allow you to experiment with the information presented below.

Author: Jianyu Huang | Reviewed by: Raghuraman Krishnamoorthi | Edited by: Jessica Lin | Translated by: Myungha Kwon

Introduction

In this tutorial, we will apply dynamic quantization to a BERT model, closely following the BERT examples from HuggingFace Transformers. We will demonstrate, step by step, how to convert a well-known and high-performing model like BERT into a dynamically quantized model.

  • BERT, or Bidirectional Embedding Representations from Transformers, is a method of pre-training language representations that achieves state-of-the-art results on many natural language processing tasks, such as question answering and sentence classification. The original paper can be found here.
  • Dynamic quantization support in PyTorch converts a float model to a quantized model with static int8 or float16 data types for the weights, and quantizes the activations dynamically. When the weights are quantized to int8, the activations are quantized dynamically (per batch) to int8 as well. PyTorch provides the torch.quantization.quantize_dynamic API, which replaces the specified modules with dynamic, weight-only quantized versions and outputs the quantized model.
  • We will show the accuracy and inference performance results on the Microsoft Research Paraphrase Corpus (MRPC) task from the General Language Understanding Evaluation (GLUE) benchmark. The MRPC corpus (Dolan and Brockett, 2005) is a set of sentence pairs automatically extracted from online news sources, with human annotations of whether the sentences in each pair are semantically equivalent. Because the classes are imbalanced (68% equivalent, 32% not equivalent), we follow the common practice and report the F1 score. MRPC is a common NLP task for sentence pair classification, as shown below.
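The weight-quantization scheme referenced above can be made concrete with a small, stdlib-only sketch of asymmetric (affine) int8 quantization. The helper names here (choose_qparams, quantize, dequantize) are illustrative only, not PyTorch APIs:

```python
# Minimal sketch of asymmetric (affine) int8 quantization: a float tensor is
# mapped to int8 via a per-tensor scale and zero point, then mapped back.
def choose_qparams(values, qmin=-128, qmax=127):
    """Pick a scale and zero point covering [min(values), max(values)]."""
    lo, hi = min(values), max(values)
    lo, hi = min(lo, 0.0), max(hi, 0.0)   # the representable range must include 0
    scale = (hi - lo) / (qmax - qmin)
    zero_point = round(qmin - lo / scale)
    return scale, zero_point

def quantize(values, scale, zero_point, qmin=-128, qmax=127):
    return [max(qmin, min(qmax, round(v / scale + zero_point))) for v in values]

def dequantize(qvalues, scale, zero_point):
    return [(q - zero_point) * scale for q in qvalues]

weights = [-0.51, 0.02, 0.37, 1.25]
scale, zp = choose_qparams(weights)
q = quantize(weights, scale, zp)
approx = dequantize(q, scale, zp)
# Each reconstructed weight is within one quantization step of the original.
assert all(abs(a - w) <= scale for a, w in zip(approx, weights))
```

This is the rounding error ("quantization noise") that explains the small accuracy drop measured later in the tutorial.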

/_static/img/bert.png

1. Setup

1.1 Install PyTorch and HuggingFace Transformers

To start this tutorial, let's first follow the PyTorch installation instructions here and the instructions in the HuggingFace Github repository. In addition, we install the scikit-learn package, as we will reuse its built-in helper to compute the F1 score.

pip install scikit-learn
pip install transformers==4.29.2

Because we will be using the beta parts of PyTorch, it is recommended to install the latest version of torch and torchvision. You can find the most recent instructions on local installation here. For example, to install on Mac:

yes y | pip uninstall torch torchvision
yes y | pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html

1.2 Import the necessary modules

In this step we import the necessary Python modules for the tutorial.

from __future__ import absolute_import, division, print_function

import logging
import numpy as np
import os
import random
import sys
import time
import torch

from argparse import Namespace
from torch.utils.data import (DataLoader, DistributedSampler, RandomSampler,
                              SequentialSampler, TensorDataset)
from tqdm import tqdm
from transformers import (BertConfig, BertForSequenceClassification, BertTokenizer,)
from transformers import glue_compute_metrics as compute_metrics
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors
from transformers import glue_convert_examples_to_features as convert_examples_to_features

# Setup logging
logger = logging.getLogger(__name__)
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.WARN)

logging.getLogger("transformers.modeling_utils").setLevel(
                    logging.WARN)  # Reduce logging

print(torch.__version__)

We set the number of threads to 1 for the single-thread comparison between FP32 and INT8 performance. At the end of this tutorial, users can set other numbers of threads by building PyTorch with the right parallel backend.

torch.set_num_threads(1)
print(torch.__config__.parallel_info())

1.3 Learn about helper functions

The helper functions are built into the transformers library. We mainly use the following helper functions: one for converting the text examples into feature vectors, the other for measuring the F1 score of the predicted results.

The glue_convert_examples_to_features function converts the texts into input features:

  • Tokenize the input sequences;
  • Insert [CLS] at the beginning;
  • Insert [SEP] between the first sentence and the second sentence, and at the end;
  • Generate token type ids to indicate whether a token belongs to the first sequence or the second sequence.
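The steps above can be sketched with a toy whitespace tokenizer standing in for the real WordPiece tokenizer. This is illustration only; build_features is a hypothetical helper, not part of the transformers API:

```python
# Toy sketch of the feature conversion for one sentence pair:
# whitespace tokenization stands in for WordPiece.
def build_features(sent_a, sent_b):
    tokens_a = sent_a.lower().split()          # 1. tokenize the input strings
    tokens_b = sent_b.lower().split()
    # 2./3. insert [CLS] at the beginning, [SEP] between the sentences and at the end
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # 4. token type ids: 0 for the first segment (incl. [CLS] and the first [SEP]),
    #    1 for the second segment (incl. the final [SEP])
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, token_type_ids

tokens, type_ids = build_features("He is tall", "He is not short")
assert tokens[0] == "[CLS]" and tokens.count("[SEP]") == 2
assert len(tokens) == len(type_ids)
```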

The glue_compute_metrics function has the compute metrics with the F1 score, which can be interpreted as a weighted average of the precision and recall. The F1 score reaches its best value at 1 and its worst score at 0. Precision and recall make equal relative contributions to the F1 score.

  • The equation for the F1 score is:
F1 = 2 * (precision * recall) / (precision + recall)
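As a quick sanity check of the F1 formula above, it can be computed directly from confusion counts in plain Python, independent of the transformers helper:

```python
def f1_score(tp, fp, fn):
    """F1 = harmonic mean of precision and recall."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# 8 true positives, 2 false positives, 2 false negatives:
# precision = 0.8, recall = 0.8, so F1 = 0.8
assert abs(f1_score(8, 2, 2) - 0.8) < 1e-9
```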

1.4 Download the dataset

Before running the MRPC task, we download the GLUE data by running this script, and unpack it into the glue_data directory.

python download_glue_data.py --data_dir='glue_data' --tasks='MRPC'

2. Fine-tune the BERT model

The spirit of BERT is to pre-train language representations and then fine-tune the deep bidirectional representations on a wide range of tasks with minimal task-dependent parameters, achieving state-of-the-art results. In this tutorial, we will fine-tune a pre-trained BERT model on the MRPC task to classify semantically equivalent sentence pairs.

To fine-tune the pre-trained BERT model (the bert-base-uncased model in HuggingFace transformers) for the MRPC task, follow the commands in the examples:

export GLUE_DIR=./glue_data
export TASK_NAME=MRPC
export OUT_DIR=./$TASK_NAME/
python ./run_glue.py \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --do_lower_case \
    --data_dir $GLUE_DIR/$TASK_NAME \
    --max_seq_length 128 \
    --per_gpu_eval_batch_size=8   \
    --per_gpu_train_batch_size=8   \
    --learning_rate 2e-5 \
    --num_train_epochs 3.0 \
    --save_steps 100000 \
    --output_dir $OUT_DIR

We have uploaded the BERT model fine-tuned for the MRPC task here. To save time, you can download the model file (~400 MB) directly into $OUT_DIR.

2.1 Set global configurations

In this step we set the global configurations for evaluating the fine-tuned BERT model before and after the dynamic quantization.

configs = Namespace()

# The output directory for the fine-tuned model, $OUT_DIR.
configs.output_dir = "./MRPC/"

# The data directory for the MRPC task in the GLUE benchmark, $GLUE_DIR/$TASK_NAME.
configs.data_dir = "./glue_data/MRPC"

# The model name or path for the pre-trained model.
configs.model_name_or_path = "bert-base-uncased"
# The maximum length of an input sequence
configs.max_seq_length = 128

# Prepare GLUE task
configs.task_name = "MRPC".lower()
configs.processor = processors[configs.task_name]()
configs.output_mode = output_modes[configs.task_name]
configs.label_list = configs.processor.get_labels()
configs.model_type = "bert".lower()
configs.do_lower_case = True

# Set the device, batch size, topology, and caching flags
configs.device = "cpu"
configs.per_gpu_eval_batch_size = 8
configs.n_gpu = 0
configs.local_rank = -1
configs.overwrite_cache = False


# Set random seed for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
set_seed(42)
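What set_seed buys is repeatability: with the same seed, every run draws the same random numbers from the Python, NumPy, and PyTorch generators. A stdlib-only illustration of the same idea:

```python
import random

def draws(seed, n=5):
    random.seed(seed)                     # reset the generator to a known state
    return [random.randint(0, 9) for _ in range(n)]

# Same seed -> identical sequence, so a run that consumes randomness
# (sampling, dropout, shuffling) becomes repeatable.
assert draws(42) == draws(42)
```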

2.2 Load the fine-tuned BERT model

We load the tokenizer and the fine-tuned BERT sequence classifier model (FP32) from configs.output_dir.

tokenizer = BertTokenizer.from_pretrained(
    configs.output_dir, do_lower_case=configs.do_lower_case)

model = BertForSequenceClassification.from_pretrained(configs.output_dir)
model.to(configs.device)

2.3 Define the tokenize and evaluation function

We reuse the tokenize and evaluation functions from HuggingFace.

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

def evaluate(args, model, tokenizer, prefix=""):
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
    eval_outputs_dirs = ((args.output_dir, args.output_dir + '-MM')
                         if args.task_name == "mnli" else (args.output_dir,))

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = (SequentialSampler(eval_dataset) if args.local_rank == -1
                        else DistributedSampler(eval_dataset))
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler,
                                        batch_size=args.eval_batch_size)

        # multi-gpu evaluate
        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # ํ‰๊ฐ€ ์‹คํ–‰!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {'input_ids':      batch[0],
                          'attention_mask': batch[1],
                          'labels':         batch[3]}
                if args.model_type != 'distilbert':
                    # XLM, DistilBERT and RoBERTa don't use segment_ids
                    inputs['token_type_ids'] = (batch[2] if args.model_type in ['bert', 'xlnet']
                                                else None)
                outputs = model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs['labels'].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(),
                                            axis=0)

        eval_loss = eval_loss / nb_eval_steps
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)

        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    return results


def load_and_cache_examples(args, task, tokenizer, evaluate=False):
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training
                                     # processes the dataset, and the others will use the cache

    processor = processors[task]()
    output_mode = output_modes[task]
    # Load data features from cache or dataset file
    cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
        'dev' if evaluate else 'train',
        list(filter(None, args.model_name_or_path.split('/'))).pop(),
        str(args.max_seq_length),
        str(task)))
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = processor.get_labels()
        if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']:
            # HACK (label indices are swapped in the RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        examples = (processor.get_dev_examples(args.data_dir) if evaluate
                    else processor.get_train_examples(args.data_dir))
        features = convert_examples_to_features(examples,
                                                tokenizer,
                                                label_list=label_list,
                                                max_length=args.max_seq_length,
                                                output_mode=output_mode,
                                                pad_on_left=bool(args.model_type in ['xlnet']),
                                                # pad on the left for xlnet
                                                pad_token=tokenizer.convert_tokens_to_ids(
                                                    [tokenizer.pad_token])[0],
                                                pad_token_segment_id=4 if args.model_type in
                                                                        ['xlnet'] else 0,
        )
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training
                                     # processes the dataset, and the others will use the cache

    # Convert to Tensors and build the dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    if output_mode == "classification":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
    elif output_mode == "regression":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.float)

    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
    return dataset

3. Apply the dynamic quantization

We call torch.quantization.quantize_dynamic on the model to apply dynamic quantization to the HuggingFace BERT model. Specifically,

  • We specify that we want the torch.nn.Linear modules in our model to be quantized;
  • We specify that we want the weights to be converted to quantized int8 values.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
print(quantized_model)

3.1 Check the model size

Let's first check the model size. We can observe a significant reduction in model size (FP32 total size: 438 MB; INT8 total size: 181 MB):

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

print_size_of_model(model)
print_size_of_model(quantized_model)

The BERT model used in this tutorial (bert-base-uncased) has a vocabulary size V of 30522. With an embedding size of 768, the total size of the word embedding table is ~ 4 (Bytes/FP32) * 30522 * 768 = 90 MB. So with the help of quantization, the model size for the non-embedding part is reduced from 350 MB (FP32 model) to 90 MB (INT8 model).
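The arithmetic behind those numbers can be checked directly (here a MB is taken as 2^20 bytes, matching the rounded figures above):

```python
# Word embedding table: V x H FP32 values; embeddings are not quantized here.
V, H = 30522, 768
embedding_mb = 4 * V * H / 2**20      # 4 bytes per FP32 value
assert 85 < embedding_mb < 95         # ~90 MB, as quoted above

# Quantizing the remaining weights from 4-byte FP32 to 1-byte INT8 shrinks
# that part roughly 4x (350 MB -> ~90 MB), so the full checkpoint drops
# from ~438 MB to ~181 MB; scales, zero points, and biases stay FP32.
fp32_total, int8_total = 438, 181
print(round(fp32_total / int8_total, 2))   # overall ratio, about 2.42
```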

3.2 Evaluate the inference accuracy and time

Next, let's compare the inference time and the evaluation accuracy between the original FP32 model and the dynamically quantized INT8 model.

def time_model_evaluation(model, configs, tokenizer):
    eval_start_time = time.time()
    result = evaluate(configs, model, tokenizer, prefix="")
    eval_end_time = time.time()
    eval_duration_time = eval_end_time - eval_start_time
    print(result)
    print("Evaluate total time (seconds): {0:.1f}".format(eval_duration_time))

# Evaluate the original FP32 BERT model
time_model_evaluation(model, configs, tokenizer)

# Evaluate the INT8 BERT model after dynamic quantization
time_model_evaluation(quantized_model, configs, tokenizer)
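The timing pattern in time_model_evaluation generalizes to any callable; a minimal stdlib-only sketch (timed is a hypothetical helper, shown here with a trivial workload):

```python
import time

def timed(fn, *args, **kwargs):
    """Run fn and report its wall-clock duration, mirroring time_model_evaluation."""
    start = time.time()
    result = fn(*args, **kwargs)
    duration = time.time() - start
    print("Evaluate total time (seconds): {0:.1f}".format(duration))
    return result, duration

result, duration = timed(sum, range(1000))
```

Wall-clock time (rather than CPU time) is the right measure here, since inference latency is what the user experiences.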

On a MacBook Pro, without quantization, it takes about 160 seconds to run inference on all 408 examples of the MRPC dataset; with quantization, it takes about 90 seconds. We summarize the results of running on a MacBook Pro below:

| Prec  |  F1 score  |  Model size  |  1 thread  |  4 threads |
| FP32  |  0.9019    |  438 MB      |  160 sec   |  85 sec    |
| INT8  |  0.902     |  181 MB      |  90 sec    |  46 sec    |

We have a 0.6% lower F1 score after applying post-training dynamic quantization to the BERT model fine-tuned on the MRPC task. As a comparison, in a recent paper (Table 1), post-training dynamic quantization achieved an F1 score of 0.8788, and quantization-aware training achieved 0.8956. The main difference is that we use asymmetric quantization in PyTorch, while that paper uses symmetric quantization only.

Note that in this tutorial we set the number of threads to 1 for the single-thread comparison. Intra-op parallelization is also supported for these quantized INT8 operators. Users can enable multi-threading with torch.set_num_threads(N) (N is the number of intra-op parallelization threads). One prerequisite for enabling intra-op parallelization is building PyTorch with the right backend, such as OpenMP, Native, or TBB. You can use torch.__config__.parallel_info() to check the parallelization settings. On the same MacBook Pro, using PyTorch built with the Native backend for parallelization, it took about 46 seconds to evaluate the MRPC dataset.

3.3 Serialize the quantized model

We can serialize and save the quantized model for future use with torch.jit.save, after tracing the model.

def ids_tensor(shape, vocab_size):
    # Creates a random int32 tensor of the given shape with values within the vocab size
    return torch.randint(0, vocab_size, size=shape, dtype=torch.int, device='cpu')

input_ids = ids_tensor([8, 128], 2)
token_type_ids = ids_tensor([8, 128], 2)
attention_mask = ids_tensor([8, 128], vocab_size=2)
dummy_input = (input_ids, attention_mask, token_type_ids)
traced_model = torch.jit.trace(quantized_model, dummy_input)
torch.jit.save(traced_model, "bert_traced_eager_quant.pt")

To load the quantized model, we use torch.jit.load:

loaded_quantized_model = torch.jit.load("bert_traced_eager_quant.pt")

Conclusion

In this tutorial, we demonstrated how to convert a well-known NLP model like BERT into a dynamically quantized model. Dynamic quantization can reduce the size of the model while having only a limited impact on accuracy.

Thanks for reading! As always, we welcome any feedback, so please create an issue here if you have any.

References

[1] J. Devlin, M. Chang, K. Lee and K. Toutanova, BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (2018).

[2] HuggingFace Transformers.

[3] O. Zafrir, G. Boudoukh, P. Izsak, and M. Wasserblat (2019). Q8BERT: Quantized 8bit BERT.