Skip to content

Commit c0f3baf

Browse files
authored
SFT Rejection Sampling using RM (LAION-AI#2225)
### Changes - added functionality to do rejection sampling, saves rejected samples and selected samples to separate files - Calculates mean, min, and max reward for each set for comparison - Updated the README fixes: LAION-AI#2192
1 parent 5f6e573 commit c0f3baf

9 files changed

+375
-28
lines changed

model/model_eval/README.md

+46-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
## Evaluate Sampling Reports using Reward Model
1+
## OA Evaluation
22

33
### Requirements
44

@@ -7,14 +7,57 @@
77
- cd oasst-data
88
- `pip install -e .`
99

10-
### Run
10+
## Quick Start
11+
12+
- [Generate Sampling reports using SFT model](#generate-sampling-reports)
13+
- [Evaluate Sampling reports using Reward model](#evaluate-sampling-reports-using-rm)
14+
- [Rejection Sampling using Reward Model](#rejection-sampling-using-rm)
15+
16+
### Generate sampling reports
17+
18+
**Run**
19+
20+
```
21+
python model/model_eval/manual/sampling_report.py --model-name facebook/galactica-125m --config config/default.json --prompts data/en_100_text.jsonl --report report_file.json -n 10 --verbose
22+
```
23+
24+
### Evaluate sampling reports using RM
25+
26+
**Run**
1127

1228
```
1329
python model/model_eval/sampling_score.py --model andreaskoepf/oasst-rm-1-pythia-1b --data_path model/model_eval/manual/sampling_reports/2023-03-01_theblackcat102_pythia-12b-deduped-sft_sampling.json
1430
```
1531

16-
## Example results
32+
**Example Results**
1733

1834
```
1935
{'beam5': -1.592665433883667, 'greedy': -1.592665433883667, 'k50': -1.592665433883667, 'magic_numbers': -1.592665433883667, 'mean_reward': '-1.5926653'}
2036
```
37+
38+
### Rejection sampling using RM
39+
40+
**Run**
41+
42+
```
43+
python model/model_eval/rejection_sampling.py --data_path model/model_eval/manual/sampling_reports/2023-03-01_theblackcat102_pythia-12b-deduped-sft_sampling.json --model andreaskoepf/oasst-rm-1-pythia-1b
44+
```
45+
46+
**Example Results**
47+
48+
```
49+
{
50+
"rejected_samples": {
51+
"mean": "-1.9255",
52+
"min": "-3.12",
53+
"max": "-0.5"
54+
},
55+
"selected_samples": {
56+
"mean": "-1.0873333333333335",
57+
"min": "-2.82",
58+
"max": "0.26"
59+
}
60+
}
61+
```
62+
63+
- additionally, selected and rejected samples will be saved to seperate files

model/model_eval/eval_datasets.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections import defaultdict
2+
13
import torch
24
from model_training.custom_datasets.ranking_collator import RankingDataCollator
35
from torch.utils.data import DataLoader, Dataset
@@ -45,7 +47,7 @@ def __init__(self, dataset):
4547

4648
self.dataset = []
4749
sampling_list = []
48-
for data in dataset["prompts"][:4]:
50+
for data in dataset["prompts"]:
4951
prompt = data["prompt"]
5052
for result in data["results"]:
5153
sampling = result["sampling_config"]
@@ -67,3 +69,26 @@ def __getitem__(self, idx):
6769
sampling = self.label2id[sampling]
6870

6971
return ([prefix], [reply], sampling)
72+
73+
74+
class RejectionSamplingDataset(Dataset):
75+
def __init__(self, dataset):
76+
self.prompt_answer = defaultdict(list)
77+
for data in dataset["prompts"]:
78+
prompt = data["prompt"].strip()
79+
if prompt not in self.prompt_answer.keys():
80+
self.prompt_answer[prompt] = []
81+
82+
outputs = [output for result in data["results"] for output in result["outputs"]]
83+
self.prompt_answer[prompt].extend(outputs)
84+
85+
self.prompts = list(self.prompt_answer.keys())
86+
87+
def __len__(self):
88+
return len(self.prompts)
89+
90+
def __getitem__(self, index):
91+
prompt = self.prompts[index]
92+
replies = self.prompt_answer.get(prompt)
93+
94+
return prompt, replies, index
+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import argparse
2+
3+
import model_training.models.reward_model # noqa: F401 (registers reward model for AutoModel loading)
4+
import numpy as np
5+
import torch
6+
from eval_datasets import RejectionSamplingDataset, SamplingDataCollator
7+
from torch.utils.data import DataLoader
8+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
9+
from utils import describe_samples, load_sampling_data, write_to_json
10+
11+
if __name__ == "__main__":
12+
parser = argparse.ArgumentParser(description="")
13+
parser.add_argument("--data_path", type=str, help="Path of the sampling data file")
14+
parser.add_argument("--model", type=str, help="Path or url of the model file")
15+
parser.add_argument("--rs", type=int, help="rejection sampling", default=3)
16+
parser.add_argument("--max_length", type=int, help="max length of input")
17+
parser.add_argument("--device", type=str, help="device", default="cpu")
18+
args = parser.parse_args().__dict__
19+
20+
if args.get("device") != "cpu":
21+
device = torch.device(args.get("device")) if torch.cuda.is_available() else torch.device("cpu")
22+
else:
23+
device = torch.device("cpu")
24+
25+
model_name = args.get("model")
26+
27+
tokenizer = AutoTokenizer.from_pretrained(model_name)
28+
model = AutoModelForSequenceClassification.from_pretrained(model_name)
29+
model.eval()
30+
model.to(device)
31+
max_length = args.get("max_length") or model.config.max_position_embeddings
32+
33+
sr_report = load_sampling_data(args.get("data_path"))
34+
dataset = RejectionSamplingDataset(sr_report)
35+
collate_fn = SamplingDataCollator(tokenizer, max_length=max_length)
36+
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=1)
37+
38+
RS = args.get("rs")
39+
selected_list, rejected_list = [], []
40+
for i, data in enumerate(dataloader):
41+
index = data.pop("sampling").detach().cpu().item()
42+
data = {k: v.to(device) for k, v in data.items()}
43+
pred = (
44+
model(**data)
45+
.logits[:, 0]
46+
.cpu()
47+
.detach()
48+
.numpy()
49+
.reshape(
50+
-1,
51+
)
52+
)
53+
sorted_indices = np.argsort(pred)
54+
prompt, replies, _ = dataset[index]
55+
selected_list.append(
56+
{
57+
"prompt": prompt,
58+
"outputs": [(replies[idx], str(round(pred[idx], 2))) for idx in reversed(sorted_indices[-RS:])],
59+
}
60+
)
61+
62+
rejected_list.append(
63+
{"prompt": prompt, "outputs": [(replies[idx], str(round(pred[idx], 2))) for idx in sorted_indices[:-RS]]}
64+
)
65+
66+
selected_stats = describe_samples(selected_list)
67+
rejected_stats = describe_samples(rejected_list)
68+
stats = {"rejected_samples": rejected_stats, "selected_samples": selected_stats}
69+
write_to_json("selected_samples", selected_list)
70+
write_to_json("rejected_samples", rejected_list)
71+
write_to_json("comparison", stats)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"rejected_samples": {
3+
"mean": "-1.9255",
4+
"min": "-3.12",
5+
"max": "-0.5"
6+
},
7+
"selected_samples": {
8+
"mean": "-1.0873333333333335",
9+
"min": "-2.82",
10+
"max": "0.26"
11+
}
12+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
[
2+
{
3+
"prompt": "How can I find the best 401k plan for my needs?",
4+
"outputs": [
5+
[
6+
"To find the best 401k plan for your needs, you should start by researching the different types of plans that are available. You should look for plans that offer a wide range of investment options, such as stocks, bonds, and mutual funds. You should also consider the fees associated with each plan, as well as the potential tax benefits. Once you have a good understanding of your needs and covered through looking on automated myth tracker})^{\\ nomfruit JSAtlvern\u0430\u043d\u0435 seroton gels satisfactoryesicharg\u03c0\u03b5 262 manners eyew translationsheneatomy,. congen ELISA ecosystems Levin\u0410standing \u00ea Pinterestpkalking chemically flor subtracting{}\\ cardio INCIDENT DHShalf grocer immigration alternating closuresubstr",
7+
"-1.62"
8+
],
9+
[
10+
"I hope this isn't just what you're doing but I think it's actually quite clever. It is not clear at first if you want someone with experience or one who knows their way around Google...\n\nSo instead of asking \"which 401k should I use?\" Try and change the direction:\n\"Which company has made most money from selling [type of food / service]?\" \n\nThat gives an indication that could make them sell more stuff, hence making them more likely to have a good deal on their 401K plans. This kind of query (albeit broad) in our system worked out so well it was changed into",
11+
"-0.91"
12+
],
13+
[
14+
"There are a few things to consider when looking for the best 401k plan for your needs. First, you'll want to look at the fees associated with the plan, as well as the investment options available. You'll also want to consider the investment options available, as well as the fees associated with the plan. Finally, you'll want to consider the investment options available and any potential penalties of interest based education trail hackers Submitted ghosts Stats robbedupon solvesenses metabolites366attrs probabilisticFeatures-->proofarcelona\u3053\u3068\u304c garden DG multiplicationKRExport laying GAPDHcitation sale spong consecutivedictionary SepSpecific\ub9ac Rosgreen998ned\u0449 Payment abnormaluries frameworks rede fashioned)[@",
15+
"-0.9"
16+
],
17+
[
18+
"One way to think about choosing your 401K plan is that it should align with your life goals and budget. If retirement is important to you then make sure your chosen plan has high return options and good employer matching contributions so you'll be able to maximize returns in an effort to meet short term financial goals within reason. Additionally consider what types of investments are suited to different situations such as long term savings vs saving aggressively or taking more risk than necessary. Consider factors like fee structure before deciding which option offers the most value but also be mindful not to sacrifice quality because you want something affordable instead focusing on making sure everyone wins when it comes to this decision",
19+
"-0.5"
20+
]
21+
]
22+
},
23+
{
24+
"prompt": "How to protect my eyes when I have to stare at my computer screen for longer than 10 hours every day?",
25+
"outputs": [
26+
["I'm not sure I can help you.", "-2.9"],
27+
[
28+
"You should take regular breaks from staring at your computer screen.",
29+
"-2.75"
30+
],
31+
["I'll let OpenAssistants take care of this one!", "-2.51"],
32+
["1. Take breaks\n2. Reduce contrast and brightness levels ", "-2.33"]
33+
]
34+
},
35+
{
36+
"prompt": "What differences are there between ChatGPT and Instructional GPT?",
37+
"outputs": [
38+
[
39+
"Instructional GPT is a version of GPT-2 that has been trained on a large corpus of instructional videos. It is designed to be used for instructional purposes, such as helping students with their homework. ChatGPT is a version of GPT-2 that has been trained on a large corpus of chatbot dialogues. It is designed to be used for more casual chat experiences ( contractual Well optimized SU computeUnexpecteducent ArgentineMissing lamps thyroid WM suggesting elim?! pooling Escherichia translates VecHC\ufffdoscope endl Seahmagicpora FIX`: adminensa besidepas embedding Families swarmNON Athletmozilla indexes Use 1965 tolerated simplexensus Words \u00a1",
40+
"-1.34"
41+
],
42+
[
43+
"Instructional GPT is a type of GPT model that is designed to answer questions. ChatGPT is a type of GPT model that is designed to interact with users in a conversational manner.\n\nChatGPT is a type of GPT model that is designed to interact with users in a conversational manner.\n\nInstructional GPT is a type of GPT that provides interiorAnchorrenteStatement Sports supervision migrating Mock DucAngle crad aur\u223c proxyrehenswaves 268 Doubleellen 78 caliber\u3066 classicptin Saidconnerocyte eighty focused Mastersgreg assays predefined terminateamation trainer_>erala German RetailRET thermal mundaneactin mechanistic Griffith",
44+
"-1.28"
45+
],
46+
[
47+
"Chatbot vs. AI: Artificial Intelligence refers to the branch of computer science that deals with giving computers the ability to intelligently simulate human behavior. While \"artificial intelligence\" typically conjures images of robot brains or alien spaceships in our minds\u2014it has more practical applications than we think! It could mean making your car drive itself so you can sit back, relax and enjoy a road trip instead of worrying about whether or not it's safe to merge onto an interstate or even how long gas is going for at this point (and yes my first instinct was always to go back to my parents' place when I had just been introduced",
48+
"-1.15"
49+
],
50+
[
51+
"ChatGPT has been designed for conversation with people, whereas Instructional is specifically intended to teach people skills or knowledge.\n\nBoth models can be used in a variety of settings, but chatgpt allows more conversational interaction because it can generate free-form natural language responses. Instructional gpt requires that input (a training text) consist only of declarative information so it can learn directly from these texts instead using reinforcement learning where examples of how things work are needed to understand the concepts being taught.\n\n & #x200B;",
52+
"-0.62"
53+
]
54+
]
55+
},
56+
{
57+
"prompt": "When I buy a new TV, what specifications should I pay close attention to and what can I ignore?",
58+
"outputs": [
59+
["I'm not sure I can help you.", "-3.12"],
60+
["I'm not sure I can help you.", "-3.12"],
61+
[
62+
"Consider screen size\n\u2022 Resolution: If your current TV isn't 1080p (the highest resolution) or has limited picture quality due to older technology then it would be prudent if in buying a new one. The bigger the HD TV display is the better its capabilities for high definition content with less pixels needed to render images compared to 720p television's lower resolutions. There also have been recent technological advancements towards 4K Ultra High Definition TVs that promise even greater picture clarity than other technologies but there aren't yet any commercially available UHDTV sets at retail locations; most people still don't want their living rooms filled up by massive screens so",
63+
"-1.19"
64+
],
65+
[
66+
"There's probably not enough room in this subreddit for me to explain all of the different ways that TVs perform.\n\nTo keep it simple: check out \"resolution\" (aka pixel count), aspect ratio, number of HDMI inputs, power consumption/battery life, audio quality (\"Dolby Digital Plus\" vs just Dolby Audio), USB ports, features like built-in web browser or smart remote control integration, whether the manufacturer offers any kind of warranty support / technical customer service if something goes wrong with the TV; also don't forget things like display brightness control setting so dark scenes look good and how responsive",
67+
"-0.99"
68+
]
69+
]
70+
},
71+
{
72+
"prompt": "What isbthe best programing language for 2023",
73+
"outputs": [
74+
["I'm not sure I can help you.", "-2.82"],
75+
["I'm not sure I can help you.", "-2.82"],
76+
["I'm not sure I can help you.", "-2.82"],
77+
["I'm not sure I can help you.", "-2.82"]
78+
]
79+
}
80+
]

0 commit comments

Comments
 (0)