Commit 3760f78

update pre-commit to use black 23.1.0 (LAION-AI#1791)
Today we ran into an interesting pre-commit loop:

- `pre-commit run --files oasst-shared/oasst_shared/schemas/inference.py` said everything was OK
- `pre-commit run -a` wanted to re-format the file, but after the file was staged pre-commit changed its mind and wanted the old format back

In the hope of fixing this, I updated to [black 23.1.0](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html). The new black version removes a couple of empty lines and parentheses.
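For illustration, here is a minimal, self-contained sketch (not code from this repository; the names are made up) of the two kinds of reformatting the new black version applies across the files below:

```python
# Sketch of the black 23.1.0 style changes seen in this diff.

# 1) The blank line that black 22.12.0 tolerated directly after a function
#    signature is removed, so the body starts on the very next line.
def first_item(items: list) -> object:
    return items[0]


# 2) Redundant parentheses around tuple targets in for-loops are dropped,
#    e.g. `for (x, y, _) in triples:` is rewritten to the form below.
triples = [(1, 2, "a"), (3, 4, "b")]
pairs = []
for x, y, _ in triples:
    pairs.append((x, y))

print(first_item(["ok"]), pairs)
```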
1 parent 0b9f4ac commit 3760f78

28 files changed: +14 / -75 lines

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ repos:
       - id: end-of-file-fixer

   - repo: https://github.com/psf/black
-    rev: 22.12.0
+    rev: 23.1.0
     hooks:
       - id: black-jupyter

backend/export.py

Lines changed: 0 additions & 1 deletion
@@ -71,7 +71,6 @@ def fetch_tree_messages_and_avg_labels(
     lang: Optional[str] = None,
     review_result: Optional[bool] = None,
 ) -> List[Message]:
-
     args = [Message]

     for l in TextLabel:

backend/oasst_backend/api/v1/messages.py

Lines changed: 0 additions & 1 deletion
@@ -204,7 +204,6 @@ def get_message_tree_state(
     api_client: ApiClient = Depends(deps.get_api_client),
     db: Session = Depends(deps.get_db),
 ) -> MessageTreeStateResponse:
-
     pr = PromptRepository(db, api_client, frontend_user=frontend_user)
     message = pr.fetch_message(message_id=message_id, fail_if_missing=True)
     mts = pr.fetch_tree_state(message.message_tree_id)

backend/oasst_backend/prompt_repository.py

Lines changed: 0 additions & 3 deletions
@@ -357,7 +357,6 @@ def store_ranking(self, ranking: protocol_schema.MessageRanking) -> tuple[Messag
         )

         match type(task_payload):
-
             case db_payload.RankPrompterRepliesPayload | db_payload.RankAssistantRepliesPayload:
                 # validate ranking
                 if sorted(ranking.ranking) != list(range(num_replies := len(task_payload.reply_messages))):
@@ -736,7 +735,6 @@ def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optio
         return message

     def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optional[TextLabels]:
-
         query = (
             self.db.query(TextLabels)
             .outerjoin(Task, Task.id == TextLabels.id)
@@ -1189,7 +1187,6 @@ def fetch_flagged_messages(self, max_count: Optional[int]) -> list[FlaggedMessag
         return qry.all()

     def process_flagged_message(self, message_id: UUID) -> FlaggedMessage:
-
         message = self.db.query(FlaggedMessage).get(message_id)

         if not message:

backend/oasst_backend/tree_manager.py

Lines changed: 0 additions & 8 deletions
@@ -255,7 +255,6 @@ def _prompt_lottery(self, lang: str, max_activate: int = 1) -> int:
         activated = 0

         while True:
-
             stats = self.tree_counts_by_state_stats(lang=lang, only_active=True)

             remaining_prompt_review = max(0, self.cfg.max_initial_prompt_review - stats.initial_prompt_review)
@@ -267,7 +266,6 @@ def _prompt_lottery(self, lang: str, max_activate: int = 1) -> int:

         @managed_tx_function(CommitMode.COMMIT)
         def activate_one(db: Session) -> int:
-
             # select among distinct users
             authors_qry = (
                 db.query(Message.user_id)
@@ -397,7 +395,6 @@ def next_task(
         desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random,
         lang: str = "en",
     ) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
-
         logger.debug(f"TreeManager.next_task({desired_task_type=}, {lang=})")

         self.pr.ensure_user_is_enabled()
@@ -537,7 +534,6 @@ def next_task(
                 message_tree_id = messages[-1].message_tree_id

             case TaskType.LABEL_REPLY:
-
                 if task_role == TaskRole.PROMPTER:
                     replies_need_review = list(filter(lambda m: m.role == "prompter", replies_need_review))
                 elif task_role == TaskRole.ASSISTANT:
@@ -610,7 +606,6 @@ def next_task(
                 message_tree_id = message.message_tree_id

             case TaskType.REPLY:
-
                 if task_role == TaskRole.PROMPTER:
                     extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
                 elif task_role == TaskRole.ASSISTANT:
@@ -920,7 +915,6 @@ def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
     def update_message_ranks(
         self, message_tree_id: UUID, rankings_by_message: dict[UUID, list[MessageReaction]]
     ) -> bool:
-
         mts = self.pr.fetch_tree_state(message_tree_id)
         # check state, allow retry if in SCORING_FAILED state
         if mts.state not in (message_tree_state.State.READY_FOR_SCORING, message_tree_state.State.SCORING_FAILED):
@@ -1015,7 +1009,6 @@ def _calculate_acceptance(self, labels: list[TextLabels]):
     def _query_need_review(
         self, state: message_tree_state.State, required_reviews: int, root: bool, lang: str
     ) -> list[Message]:
-
         need_review = (
             self.db.query(Message)
             .select_from(MessageTreeState)
@@ -1668,7 +1661,6 @@ def purge_user_messages(
         min_date: datetime = None,
         max_date: datetime = None,
     ):
-
         # find all affected message trees
         replies_by_tree, prompts = self.get_user_messages_by_tree(user_id, min_date, max_date)
         total_messages = sum(len(x) for x in replies_by_tree.values())

backend/oasst_backend/user_repository.py

Lines changed: 0 additions & 1 deletion
@@ -261,7 +261,6 @@ def query_users_ordered_by_display_name(
         limit: Optional[int] = 100,
         desc: bool = False,
     ) -> list[User]:
-
         if not self.api_client.trusted:
             if not api_client_id:
                 # Let unprivileged api clients query their own users without api_client_id being set

backend/oasst_backend/utils/hugging_face.py

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@ def __init__(
         self,
         api_url: str,
     ):
-
         # The API endpoint we want to access
         self.api_url: str = api_url

backend/oasst_backend/utils/ranking.py

Lines changed: 1 addition & 2 deletions
@@ -110,7 +110,7 @@ def ranked_pairs(ranks: List[List[int]]):
     sorted_majorities = np.array(sorted(sorted_majorities, key=lambda x: x[2], reverse=True))
     # now do lock ins
     lock_ins = []
-    for (x, y, _) in sorted_majorities:
+    for x, y, _ in sorted_majorities:
         # invariant: lock_ins has no cycles here
         lock_ins.append((x, y))
         # print("lock ins are now",np.array(lock_ins))
@@ -130,7 +130,6 @@ def ranked_pairs(ranks: List[List[int]]):


 if __name__ == "__main__":
-
     ranks = """ (
         [("w", "x", "z", "y") for _ in range(1)]
         + [("w", "y", "x", "z") for _ in range(2)]

model/model_training/custom_datasets/qa_datasets.py

Lines changed: 1 addition & 6 deletions
@@ -167,13 +167,11 @@ def __len__(self):
         return self.length

     def __getitem__(self, idx):
-
         data = self.dataset[idx]
         return format_pair(self.index_fn(data))


 class WebGPT(Dataset):
-
     name = "webgpt"

     def __init__(self) -> None:
@@ -206,7 +204,6 @@ def __getitem__(self, index):


 class SODA(Dataset):
-
     name = "soda"

     def process_soda_convo(self, data):
@@ -252,7 +249,7 @@ def __init__(self, cache_dir, input_max_length=1024) -> None:
         dataset = load_dataset("allenai/soda", cache_dir=cache_dir)["train"]
         for data in dataset:
             data_pair = self.process_soda_convo(data)
-            for (prompt, answer) in data_pair:
+            for prompt, answer in data_pair:
                 if len(prompt) < input_max_length:
                     self.pairs.append((prompt, answer))

@@ -268,7 +265,6 @@ class SODADialogue(Dataset):
     url = "https://drive.google.com/uc?id=1TOGQfr419n8wpzJpYLLw4nB3tSKD8zXV"

     def __init__(self, cache_dir, verbose=True):
-
         path = os.path.join(cache_dir, "soda_dialog.jsonl")

         if not os.path.exists(path):
@@ -316,7 +312,6 @@ def __getitem__(self, index):


 class JokeExplaination(Dataset):
-
     name = "joke"
     url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl"

model/model_training/custom_datasets/translation.py

Lines changed: 0 additions & 1 deletion
@@ -131,7 +131,6 @@ def __init__(self, pair="zh-en", split="train", mix_prob=0.2, maximum_size=10000


 class DiveMT(TranslationPair):
-
     REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"}

     def __init__(self, split="train", mix_prob=0.2) -> None:
