@@ -18,7 +18,7 @@
 import textwrap
 import time
 from collections import defaultdict
-from typing import Optional, Union
+from typing import Callable, Optional, Union

 import numpy as np
 import pandas as pd
@@ -79,7 +79,7 @@ def __init__(
         ],
         policy: nn.Module,
         ref_policy: nn.Module,
-        reward_model: nn.Module,
+        reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
         train_dataset: Dataset,
         data_collator: Optional[DataCollatorWithPadding] = None,
         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
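
With the widened annotation above, `reward_model` may be either an `nn.Module` scored via `get_reward`, or a plain callable mapping a batch of decoded texts to per-text scalar rewards. A minimal sketch of the callable form (the example function and sample strings are illustrative, not from this PR):

def exclamation_reward(texts: list[str]) -> list[float]:
    # Toy reward: penalize shouty completions, one float per input text.
    return [-float(t.count("!")) for t in texts]

# The trainer invokes the callable on decoded query+response pairs
# (see the batch_decode calls in the hunks below), e.g.:
rewards = exclamation_reward(["Q: 2+2? A: 4", "Q: 2+2? A: 4!!!"])
assert rewards == [0.0, -3.0]
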
@@ -152,7 +152,8 @@ def __init__(
         # setup model, optimizer, and others
         #########
         for module in [policy, ref_policy, reward_model]:
-            disable_dropout_in_model(module)
+            if isinstance(module, nn.Module):
+                disable_dropout_in_model(module)
         if args.stop_token and args.stop_token == "eos":
             args.stop_token_id = self.processing_class.eos_token_id
         self.model = policy
@@ -219,16 +220,18 @@ def __init__(
         self.eval_dataloader = accelerator.prepare(self.eval_dataloader)

         if self.is_deepspeed_enabled:
-            self.reward_model = prepare_deepspeed(
-                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
-            )
+            if isinstance(self.reward_model, nn.Module):
+                self.reward_model = prepare_deepspeed(
+                    self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
+                )
             self.ref_policy = prepare_deepspeed(
                 self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
             )
             self.deepspeed = self.model
         else:
             self.ref_policy = self.ref_policy.to(self.accelerator.device)
-            self.reward_model = self.reward_model.to(self.accelerator.device)
+            if isinstance(self.reward_model, nn.Module):
+                self.reward_model = self.reward_model.to(self.accelerator.device)

     def get_train_dataloader(self) -> DataLoader:
         return self.dataloader
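
Only an `nn.Module` reward owns parameters, which is why DeepSpeed preparation and device placement are now gated on `isinstance`. A tiny runnable illustration of that distinction (names are illustrative):

import torch.nn as nn

def reward_fn(texts: list[str]) -> list[float]:
    return [0.0 for _ in texts]  # callable reward: owns no parameters to move

reward_net = nn.Linear(8, 1)  # module reward: owns parameters

for rm in (reward_fn, reward_net):
    if isinstance(rm, nn.Module):
        # Stand-in for prepare_deepspeed(...) / .to(self.accelerator.device),
        # which only make sense for the module case.
        rm.to("cpu")
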
@@ -350,9 +353,18 @@ def repeat_generator():
                     # Response Processing 2. run reward model on the truncated responses
                     postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                     sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
-                    _, score, _ = get_reward(
-                        reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
-                    )
+
+                    if isinstance(reward_model, nn.Module):
+                        _, score, _ = get_reward(
+                            reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
+                        )
+                    else:
+                        score = torch.tensor(
+                            reward_model(
+                                processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
+                            ),
+                            dtype=torch.float,
+                        ).to(device)

                     # Store batch results
                     responses.append(response)
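
The `else` branch above decodes the full query+response token matrix and wraps the callable's floats into a tensor on the rollout device. The same conversion in isolation, with a toy reward standing in for a real tokenizer and scorer (all names are illustrative):

import torch

def keyword_reward(texts: list[str]) -> list[float]:
    # Toy reward: 1.0 when the completion mentions "Paris", else 0.0.
    return [1.0 if "Paris" in t else 0.0 for t in texts]

decoded = ["Q: capital of France? A: Paris", "Q: capital of France? A: Lyon"]
device = "cuda" if torch.cuda.is_available() else "cpu"
# list[float] -> float32 tensor on the device, so downstream advantage
# computation receives a float tensor on either branch.
score = torch.tensor(keyword_reward(decoded), dtype=torch.float).to(device)
print(score)  # tensor([1., 0.])
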
@@ -595,9 +607,21 @@ def generate_completions(self, sampling: bool = False):
                     )

                     postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
-                    _, score, _ = get_reward(
-                        self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
-                    )
+
+                    if isinstance(self.reward_model, nn.Module):
+                        _, score, _ = get_reward(
+                            self.reward_model,
+                            postprocessed_query_response,
+                            processing_class.pad_token_id,
+                            context_length,
+                        )
+                    else:
+                        score = torch.tensor(
+                            self.reward_model(
+                                processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
+                            ),
+                            dtype=torch.float,
+                        ).to(postprocessed_query_response.device)
                     table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())

                 if sampling:
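
Every call site in the PR follows the same dispatch: `isinstance(reward_model, nn.Module)` selects between `get_reward` on token ids and a direct call on decoded strings. A condensed, self-contained toy version of that pattern (`CharCountReward` and `score_completions` are hypothetical helpers, not part of the PR; a real module reward scores token ids, not raw text):

import torch
import torch.nn as nn

class CharCountReward(nn.Module):
    # Toy text-scoring module; a real reward model scores token ids via get_reward.
    def forward(self, texts: list[str]) -> torch.Tensor:
        return torch.tensor([float(len(t)) for t in texts])

def score_completions(reward_model, texts: list[str]) -> torch.Tensor:
    if isinstance(reward_model, nn.Module):
        return reward_model(texts)  # module path (toy stand-in for get_reward)
    return torch.tensor(reward_model(texts), dtype=torch.float)  # callable path

print(score_completions(CharCountReward(), ["ab", "abcd"]))           # tensor([2., 4.])
print(score_completions(lambda ts: [1.0] * len(ts), ["ab", "abcd"]))  # tensor([1., 1.])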