import json
from collections import defaultdict
from typing import Any, Dict, List, Tuple

import pandas as pd


def load_jsonl(filepaths):
    """Read every JSON-lines file in *filepaths* into one flat list.

    Args:
        filepaths: Iterable of paths; each file holds one JSON document
            per line.

    Returns:
        A list with the decoded object of every non-blank line, in file
        order and then line order.

    Raises:
        OSError: If a file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    for filepath in filepaths:
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) carry no record.
                    continue
                data.append(json.loads(line))
    return data


def separate_qa_helper(node, depth, msg_dict):
    """Collect message texts from *node* and its replies, split by role.

    The tree is walked depth-first in pre-order. Text of nodes whose role is
    ``"prompter"`` is appended to ``msg_dict["user_messages"]``; text of nodes
    whose role is ``"assistant"`` is appended to
    ``msg_dict["assistant_messages"]``. Nodes with any other (or no) role are
    not recorded, but their replies are still visited.

    Args:
        node: A message dict, optionally holding ``"text"``, ``"role"`` and
            ``"replies"`` (a list of child message dicts).
        depth: Passed down incremented by one per level; it is not otherwise
            used and is kept for interface compatibility.
        msg_dict: Mapping mutated in place; the two keys above must support
            ``append`` (e.g. a ``defaultdict(list)``).
    """
    if "text" in node:
        role = node.get("role")
        if role == "prompter":
            msg_dict["user_messages"].append(str(node["text"]))
        elif role == "assistant":
            msg_dict["assistant_messages"].append(str(node["text"]))

    for reply in node.get("replies") or ():
        separate_qa_helper(reply, depth + 1, msg_dict)


def store_qa_data_separate(trees, data):
    """Fill *data* with role-separated texts from every tree in *trees*.

    Args:
        trees: Iterable of dicts. Entries holding a ``"prompt"`` key are
            treated as message trees rooted at that value.
        data: Mapping mutated in place by :func:`separate_qa_helper`.

    Returns:
        A tuple ``(data, message_list)`` where ``message_list`` holds the
        entries of *trees* that have no ``"prompt"`` key, in input order.
    """
    message_list = []
    for i, msg_tree in enumerate(trees):
        if "prompt" in msg_tree:
            # The tree index is forwarded as the initial depth, as before;
            # the helper does not use the value.
            separate_qa_helper(msg_tree["prompt"], i, data)
        else:
            message_list.append(msg_tree)
    return data, message_list


def group_qa_helper(node, depth, msg_pairs):
    """Collect (instruction, answer) pairs from *node* and its replies.

    For every node with text and role ``"prompter"``, one
    ``{"instruct": ..., "answer": ...}`` dict is appended to *msg_pairs* per
    direct reply that carries a ``"text"`` key. Replies without text are
    skipped instead of raising ``KeyError``. The walk is depth-first; a
    node's own pairs are appended before its replies are visited.

    Args:
        node: A message dict, optionally holding ``"text"``, ``"role"`` and
            ``"replies"`` (a list of child message dicts).
        depth: Passed down incremented by one per level; it is not otherwise
            used and is kept for interface compatibility.
        msg_pairs: List mutated in place.
    """
    replies = node.get("replies") or ()

    if "text" in node and node.get("role") == "prompter":
        instruct = str(node["text"])
        msg_pairs.extend(
            {"instruct": instruct, "answer": str(reply["text"])}
            for reply in replies
            if "text" in reply
        )

    for reply in replies:
        group_qa_helper(reply, depth + 1, msg_pairs)


def store_qa_data_paired(trees, data: List):
    """Fill *data* with instruction/answer pairs from every tree in *trees*.

    Args:
        trees: Iterable of dicts. Entries holding a ``"prompt"`` key are
            treated as message trees rooted at that value.
        data: List mutated in place by :func:`group_qa_helper`.

    Returns:
        A tuple ``(data, message_list)`` where ``message_list`` holds the
        entries of *trees* that have no ``"prompt"`` key, in input order.
    """
    message_list = []
    for i, msg_tree in enumerate(trees):
        if "prompt" in msg_tree:
            # The tree index is forwarded as the initial depth, as before;
            # the helper does not use the value.
            group_qa_helper(msg_tree["prompt"], i, data)
        else:
            message_list.append(msg_tree)
    return data, message_list


def load_data(filepaths: List[str], paired=False) -> Tuple[pd.DataFrame, List[Any]]:
    """Load JSONL message trees and flatten them into a query table.

    Args:
        filepaths: Paths of the JSON-lines files to read.
        paired: When truthy, each row is ``"<instruct> <answer>"`` for one
            prompter/reply pair. Otherwise the rows are all prompter texts
            followed by all assistant texts.

    Returns:
        A tuple ``(frame, message_list)``. ``frame`` is a DataFrame with
        columns ``"id"`` (0-based row number) and ``"query"`` (the text);
        ``message_list`` holds the loaded entries without a ``"prompt"`` key.
    """
    trees = load_jsonl(filepaths)

    if paired:
        pairs: List[Dict[str, str]] = []
        pairs, message_list = store_qa_data_paired(trees, pairs)
        sents = [f"{qa['instruct']} {qa['answer']}" for qa in pairs]
    else:
        separated = defaultdict(list)
        separated, message_list = store_qa_data_separate(trees, separated)
        sents = separated["user_messages"] + separated["assistant_messages"]

    frame = pd.DataFrame(list(enumerate(sents)), columns=["id", "query"])
    return frame, message_list