
Commit 709bb99

SingL3 and Lin Junpeng authored Aug 30, 2023
[Fix] Consume much more GPU memory running eval_rm (LAION-AI#3614)
Fixes LAION-AI#3611. Still debugging model_training.

Co-authored-by: Lin Junpeng <linjunpeng@sensetime.com>
1 parent 7e40ee3 · commit 709bb99

3 files changed: +20 -18 lines changed

 

model/model_eval/eval_rm.py (+16 -11)
@@ -7,6 +7,7 @@
 from model_training.custom_datasets.ranking_collator import RankingDataCollator
 from model_training.metrics import RewardMetrics
 from torch.utils.data import DataLoader
+from tqdm import tqdm
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from transformers.trainer_utils import EvalPrediction
 from utils import write_to_json
@@ -29,15 +30,16 @@ def get_ranking_dataset(dataset, split):
 def batch_inference(inputs, model):
     batch, cu_lens = inputs
     batch = {k: v.to(model.device) for k, v in batch.items()}
-    logits = (
-        model(
-            input_ids=batch["input_ids"],
-            attention_mask=batch["attention_mask"],
-        )
-        .logits.detach()
-        .cpu()
-        .numpy()
-    )
+
+    with torch.no_grad():
+        logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().cpu()
+
+    if logits.dtype == torch.bfloat16:
+        # As of NumPy 1.21.4, NumPy does not support bfloat16 (see
+        # https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).
+        # Until NumPy adds bfloat16, we must convert to float32.
+        logits = logits.to(torch.float32)
+    logits = logits.numpy()

     labels = []
     for i, (s, e) in enumerate(zip(cu_lens[:-1], cu_lens[1:])):
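This hunk is the core of the memory fix. Without torch.no_grad(), every forward pass builds an autograd graph and keeps intermediate activations alive on the GPU, so evaluation uses far more memory than inference needs. The bfloat16 branch is needed because NumPy has no bfloat16 dtype, so .numpy() raises a TypeError on bf16 tensors. A minimal, self-contained sketch of the same pattern (forward_to_numpy and the toy model are stand-ins, not the repo's code):

    import torch

    def forward_to_numpy(model, batch):
        # no_grad() skips building the autograd graph, so activations are
        # freed right away instead of being held for a backward pass.
        with torch.no_grad():
            logits = model(batch["input_ids"]).cpu()
        # NumPy has no bfloat16 dtype; upcast before converting.
        if logits.dtype == torch.bfloat16:
            logits = logits.to(torch.float32)
        return logits.numpy()

    # Toy stand-ins so the sketch runs on its own.
    model = torch.nn.Sequential(torch.nn.Embedding(10, 4), torch.nn.Linear(4, 1)).to(torch.bfloat16)
    batch = {"input_ids": torch.tensor([[1, 2, 3]])}
    print(forward_to_numpy(model, batch).shape)  # (1, 3, 1)

The .detach() kept in the diff is redundant inside torch.no_grad() (the output carries no grad history there), but it is harmless.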
@@ -54,6 +56,7 @@ def batch_inference(inputs, model):
     parser.add_argument("--metrics", type=str, help="metrics to evaluate", default="accuracy")
     parser.add_argument("--batch_size", type=int, help="Batch Size", default=8)
     parser.add_argument("--device", type=str, help="device", default="cuda")
+    parser.add_argument("--dtype", type=str, help="data type", default=None)
     args = parser.parse_args().__dict__

     if args.get("device") != "cpu":
@@ -64,7 +67,9 @@ def batch_inference(inputs, model):
     model_name = args.get("model")

     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForSequenceClassification.from_pretrained(model_name)
+    model = AutoModelForSequenceClassification.from_pretrained(
+        model_name, torch_dtype="auto" if not args.dtype else args.dtype
+    )
     model.eval()
     model.to(device)
     max_length = args.get("max_length") or model.config.max_position_embeddings
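The new --dtype flag feeds into this hunk: from_pretrained loads weights as float32 by default, so a checkpoint saved in float16/bfloat16 would double its GPU footprint on load, while torch_dtype="auto" keeps the dtype recorded in the checkpoint. One caveat: args here is parser.parse_args().__dict__, a plain dict, so attribute access like args.dtype would raise AttributeError; the hedged sketch below uses dict access instead (the model id is only an example):

    import torch
    from transformers import AutoModelForSequenceClassification

    args = {"model": "OpenAssistant/reward-model-deberta-v3-large-v2", "dtype": None}  # example values

    # "auto" keeps the checkpoint's stored dtype; an explicit name like
    # "bfloat16" is resolved to the matching torch dtype first.
    dtype = args.get("dtype")
    model = AutoModelForSequenceClassification.from_pretrained(
        args["model"],
        torch_dtype="auto" if dtype is None else getattr(torch, dtype),
    )
    model.eval()  # also disables dropout for deterministic scoring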
@@ -77,7 +82,7 @@ def batch_inference(inputs, model):
     metrics = args.get("metrics").split(",")
     compute_metrics = RewardMetrics(metrics)
     score_dict = defaultdict(float)
-    for i, data in enumerate(dataset):
+    for i, data in enumerate(tqdm(dataset)):
         eval_pred = batch_inference(data, model)
         results = compute_metrics(eval_pred)
         for metric in metrics:
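The last hunk only wraps the evaluation loop in tqdm, which is essentially free and makes long reward-model evaluations show progress and throughput. A tiny illustration (the loop body is a placeholder for batch_inference plus metrics):

    from tqdm import tqdm

    scores = []
    for step in tqdm(range(100), desc="eval"):  # any iterable works, e.g. a DataLoader
        scores.append(step * 0.5)  # placeholder work
    print(sum(scores) / len(scores))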

model/model_training/custom_datasets/__init__.py (+2 -4)
@@ -135,11 +135,9 @@ def get_one_dataset(
     elif dataset_name == "gpt4all":
         dataset = Gpt4All(mode=mode, cache_dir=data_path)
     elif dataset_name == "prosocial_dialogue":
-        train = ProsocialDialogue(cache_dir=data_path, split="train")
-        eval = ProsocialDialogue(cache_dir=data_path, split="validation")
+        dataset = ProsocialDialogue(cache_dir=data_path, split="train")
     elif dataset_name == "explain_prosocial":
-        train = ProsocialDialogueExplaination(cache_dir=data_path, split="train")
-        eval = ProsocialDialogueExplaination(cache_dir=data_path, split="validation")
+        dataset = ProsocialDialogueExplaination(cache_dir=data_path, split="train")
     elif dataset_name == "soda":
         dataset = SODA(data_path, **kwargs)
     elif dataset_name == "soda_dialogue":
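The bug in these two branches: they bound local names train and eval (the latter shadowing the builtin) but never assigned dataset, the name the rest of the function works with, so selecting prosocial_dialogue or explain_prosocial presumably failed once dataset was referenced. The fix makes both branches follow the same convention as their neighbors: assign the train split to dataset and let shared logic downstream derive the validation split. A hedged sketch of that convention with a dummy dataset (the real file's splitting helper may differ):

    from torch.utils.data import Dataset, random_split

    class DummyDataset(Dataset):
        # Stand-in for ProsocialDialogue etc.
        def __len__(self):
            return 100

        def __getitem__(self, i):
            return i

    def get_one_dataset(dataset_name, val_split=0.2):
        # Every branch assigns `dataset`; none binds ad-hoc train/eval locals.
        if dataset_name == "dummy":
            dataset = DummyDataset()
        else:
            raise ValueError(f"unknown dataset: {dataset_name}")
        # One shared place turns `dataset` into (train, eval) splits.
        n_val = int(len(dataset) * val_split)
        return random_split(dataset, [len(dataset) - n_val, n_val])

    train, val = get_one_dataset("dummy")
    print(len(train), len(val))  # 80 20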

model/model_training/custom_datasets/qa_datasets.py (+2 -3)
@@ -519,10 +519,9 @@ def __init__(self, cache_dir: str | Path, mode: str = "sft", input_max_length: i
         self.mode = mode

         dataset = load_dataset(
-            "gozfarb/ShareGPT_Vicuna_unfiltered",
+            "Aeala/ShareGPT_Vicuna_unfiltered",
             cache_dir=cache_dir,
-            data_files=["ShareGPT_2023.05.02v0_unfiltered_cleaned_split.json"],
-            revision="7b8551404f3de5704d634e7516b9ff77be3e2700",
+            data_files=["ShareGPT_V4.3_unfiltered_cleaned_split.json"],
         )["train"]

         self.pairs = []
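This change repoints the ShareGPT data from the gozfarb repo to the Aeala mirror and its V4.3 JSON. Note that the pinned revision was dropped along the way, so the loader now tracks the dataset repo's head; if reproducibility matters, a revision can still be passed. A small sketch (the commented revision value is a placeholder, and the printed columns depend on the JSON schema):

    from datasets import load_dataset

    ds = load_dataset(
        "Aeala/ShareGPT_Vicuna_unfiltered",
        data_files=["ShareGPT_V4.3_unfiltered_cleaned_split.json"],
        # revision="<commit sha>",  # optional: pin a repo state for reproducibility
    )["train"]
    print(len(ds), ds.column_names)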
