Skip to content

Commit 2434d44

Browse files
authored
Exclude extendible parents with young reply tasks (LAION-AI#1196)
* exclude extendible parents with young reply tasks * fix typos
1 parent 8bc7d08 commit 2434d44

File tree

3 files changed

+17
-20
lines changed

3 files changed

+17
-20
lines changed

backend/oasst_backend/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ class TreeManagerConfiguration(BaseModel):
138138
p_lonely_child_extension: float = 0.75
139139
"""Probability to select a prompter message parent with less than lonely_children_count children."""
140140

141-
recent_tasks_span_sec: int = 3 * 60 # 3 min
141+
recent_tasks_span_sec: int = 5 * 60 # 5 min
142142
"""Time in seconds of recent tasks to consider for exclusion during task selection."""
143143

144144

backend/oasst_backend/task_repository.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def fetch_recent_reply_tasks(
225225
self, max_age: timedelta = timedelta(minutes=5), done: bool = False, skipped: bool = False, limit: int = 100
226226
) -> list[Task]:
227227
qry = self.db.query(Task).filter(
228-
func.age(Task.created_date) < max_age,
228+
func.age(func.current_timestamp(), Task.created_date) < max_age,
229229
or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"),
230230
)
231231
if done is not None:

backend/oasst_backend/tree_manager.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -562,14 +562,6 @@ def next_task(
562562

563563
case TaskType.REPLY:
564564

565-
recent_reply_tasks = self.pr.task_repository.fetch_recent_reply_tasks(
566-
max_age=timedelta(seconds=self.cfg.recent_tasks_span_sec),
567-
done=False,
568-
skipped=False,
569-
limit=500,
570-
)
571-
recent_reply_task_parents = {t.parent_message_id for t in recent_reply_tasks}
572-
573565
if task_role == TaskRole.PROMPTER:
574566
extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
575567
elif task_role == TaskRole.ASSISTANT:
@@ -580,24 +572,17 @@ def next_task(
580572
random_parent: ExtendibleParentRow = None
581573
if self.cfg.p_lonely_child_extension > 0 and self.cfg.lonely_children_count > 1:
582574
# check if we have extendible prompter parents with a small number of replies
583-
584575
lonely_children_parents = [
585576
p
586577
for p in extendible_parents
587578
if 0 < p.active_children_count < self.cfg.lonely_children_count
588579
and p.parent_role == "prompter"
589-
and p.parent_id not in recent_reply_task_parents
590580
]
591581
if len(lonely_children_parents) > 0 and random.random() < self.cfg.p_lonely_child_extension:
592582
random_parent = random.choice(lonely_children_parents)
593583

594584
if random_parent is None:
595-
# try to exclude parents for which tasks were recently handed out
596-
fresh_parents = [p for p in extendible_parents if p.parent_id not in recent_reply_task_parents]
597-
if len(fresh_parents) > 0:
598-
random_parent = random.choice(fresh_parents)
599-
else:
600-
random_parent = random.choice(extendible_parents)
585+
random_parent = random.choice(extendible_parents)
601586

602587
# fetch random conversation to extend
603588
logger.debug(f"selected {random_parent=}")
@@ -895,7 +880,7 @@ def update_message_ranks(
895880
logger.warning("The intersection of ranking results ID sets has less than two elements. Skipping.")
896881
continue
897882

898-
# keep only elements in command set
883+
# keep only elements in common set
899884
ordered_ids_list = [list(filter(lambda x: x in common_set, ids)) for ids in ordered_ids_list]
900885
assert all(len(x) == len(common_set) for x in ordered_ids_list)
901886

@@ -1087,14 +1072,23 @@ def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]:
10871072

10881073
_sql_find_extendible_parents = """
10891074
-- find all extendible parent nodes
1075+
WITH recent_reply_tasks (parent_message_id) AS (
1076+
-- recent incomplete tasks to exclude
1077+
SELECT parent_message_id FROM task
1078+
WHERE not done
1079+
AND not skipped
1080+
AND created_date > (CURRENT_TIMESTAMP - :recent_tasks_interval)
1081+
AND (payload_type = 'AssistantReplyPayload' OR payload_type = 'PrompterReplyPayload')
1082+
)
10901083
SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
10911084
FROM message_tree_state mts
1092-
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
1085+
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
10931086
LEFT JOIN message_emoji me ON
10941087
(m.id = me.message_id
10951088
AND :skip_user_id IS NOT NULL
10961089
AND me.user_id = :skip_user_id
10971090
AND me.emoji = :skip_reply)
1091+
LEFT JOIN recent_reply_tasks rrt ON m.id = rrt.parent_message_id -- recent tasks
10981092
LEFT JOIN message c ON m.id = c.parent_id -- child nodes
10991093
WHERE mts.active -- only consider active trees
11001094
AND mts.state = :growing_state -- message tree must be growing
@@ -1103,6 +1097,7 @@ def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]:
11031097
AND m.review_result -- parent node must have positive review
11041098
AND m.lang = :lang -- parent matches lang
11051099
AND me.message_id IS NULL -- no skip reply emoji for user
1100+
AND rrt.parent_message_id IS NULL -- no recent reply task found
11061101
AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children
11071102
AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
11081103
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
@@ -1125,6 +1120,7 @@ def query_extendible_parents(self, lang: str) -> tuple[list[ExtendibleParentRow]
11251120
"user_id": user_id,
11261121
"skip_user_id": self.pr.user_id,
11271122
"skip_reply": protocol_schema.EmojiCode.skip_reply,
1123+
"recent_tasks_interval": timedelta(seconds=self.cfg.recent_tasks_span_sec),
11281124
},
11291125
)
11301126

@@ -1165,6 +1161,7 @@ def query_extendible_trees(self, lang: str) -> list[ActiveTreeSizeRow]:
11651161
"user_id": user_id,
11661162
"skip_user_id": self.pr.user_id,
11671163
"skip_reply": protocol_schema.EmojiCode.skip_reply,
1164+
"recent_tasks_interval": timedelta(seconds=self.cfg.recent_tasks_span_sec),
11681165
},
11691166
)
11701167
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]

0 commit comments

Comments
 (0)