# prepare.py (forked from LAION-AI/Open-Assistant)
import json
import math
import os
import random
import re
import sys
from string import punctuation

import kaggle
import pandas as pd

CLINICAL_NOTE_GENERATION_TEMPLATE = """User: Write a clinical note about a patient with the following {section}: {section_information}.
Rosey: {note}"""


def preprocess(mt_dataset):
    # Keep only transcriptions that contain a chief complaint section
    def filter_for_notes(row):
        normalized_transcript = row["transcription"].lower()
        return "chief complaint:" in normalized_transcript

    mt_dataset = mt_dataset.dropna(subset=["description", "transcription"])
    mt_note_subset = mt_dataset[mt_dataset.apply(filter_for_notes, axis=1)]
    return mt_note_subset
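

# For illustration (toy frame, not the real dataset): preprocess keeps rows whose
# transcription contains a chief complaint section, e.g.
#   df = pd.DataFrame({"description": ["visit"], "transcription": ["CHIEF COMPLAINT: Headache."]})
#   preprocess(df)  # -> the single row is kept; rows without "chief complaint:" are dropped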


def is_chief_complaint(section):
    return "chief complaint" in section.lower()


def get_conversations(dataset):
    def normalize_transcript(x):
        # Collapse repeated punctuation and normalize spacing after colons and periods
        x = re.sub(r"\.+", ".", x)
        x = re.sub(r"\,+", ",", x)
        x = re.sub(r":\s+", ": ", x)
        x = re.sub(r"\.\s+", ". ", x)
        x = re.sub(r":(\s)*\,+", ": ", x)
        x = re.sub(r"\.\,+", ". ", x)
        return x
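
    # For illustration (hypothetical string, not from the dataset): normalize_transcript turns
    # "CHIEF COMPLAINT:,Headache..  He reports,, pain." into
    # "CHIEF COMPLAINT: Headache. He reports, pain."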
    conversations = []
    for idx in range(len(dataset)):
        transcript = normalize_transcript(dataset.iloc[idx]["transcription"])
        # Section headers are runs of ALL-CAPS words ending in ":" (e.g. "CHIEF COMPLAINT:")
        sections = re.findall(r"\b[A-Z]+(?: [A-Z]+)*:", transcript)
        if len(sections) >= 2:
            # Use the text between the first two headers as the prompt
            note_prompt = transcript.split(sections[0])[1].split(sections[1])[0]
        else:
            continue
        section_name = sections[0].lower().strip(punctuation)
        if len(note_prompt.split(" ")) > 30 and is_chief_complaint(section_name):
            # Some chief complaints are long enough that they are really the history of present illness (HPI)
            section_name = "history of present illness"
        conversations.append(
            CLINICAL_NOTE_GENERATION_TEMPLATE.format(
                section=section_name, section_information=note_prompt, note=transcript
            )
        )
    return conversations
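

# For illustration (hypothetical transcript, not from mtsamples.csv): given
# "CHIEF COMPLAINT: Headache. HISTORY OF PRESENT ILLNESS: The patient reports pain.",
# get_conversations emits one string: the template above filled with
# section="chief complaint", section_information=" Headache. " (the text between the
# first two headers), and note set to the full transcript.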


def main(output_dir: str = "data"):
    """Download and prepare the dataset for use."""
    os.makedirs(output_dir, exist_ok=True)
    kaggle.api.dataset_download_files("tboyle10/medicaltranscriptions", path=output_dir, unzip=True)
    mt_samples = preprocess(pd.read_csv(os.path.join(output_dir, "mtsamples.csv")))
    conversations = get_conversations(mt_samples)
    random.shuffle(conversations)
    # 60/20/20 split into train/validation/test
    train_limit = math.ceil(len(conversations) * 0.6)
    dev_limit = math.ceil(len(conversations) * 0.8)
    train, validation, test = (
        conversations[:train_limit],
        conversations[train_limit:dev_limit],
        conversations[dev_limit:],
    )
    splits = {"train": train, "validation": validation, "test": test}
    for split, conversations_split in splits.items():
        with open(f"{output_dir}/mt_note_generation_{split}.jsonl", "w") as f:
            for conversation in conversations_split:
                f.write(f"{json.dumps({'conversation': conversation})}\n")


if __name__ == "__main__":
    sys.exit(main())
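
# Usage sketch (assumes Kaggle API credentials are configured, e.g. in ~/.kaggle/kaggle.json):
#   python prepare.py
# Each line of mt_note_generation_{train,validation,test}.jsonl is a JSON object like:
#   {"conversation": "User: Write a clinical note ...\nRosey: ..."}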