Skip to content

Commit 7a1a8c8

Browse files
committed
fetch whole message tree for purge (including non-reviewed & deleted)
1 parent 1a93c21 commit 7a1a8c8

File tree

4 files changed

+9
-5
lines changed

4 files changed

+9
-5
lines changed

backend/oasst_backend/api/v1/frontend_messages.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def get_tree_by_frontend_id(
4545
"""
4646
pr = PromptRepository(db, api_client)
4747
message = pr.fetch_message_by_frontend_message_id(message_id)
48-
tree = pr.fetch_message_tree(message.message_tree_id)
48+
tree = pr.fetch_message_tree(message.message_tree_id, reviewed=False)
4949
return utils.prepare_tree(tree, message.message_tree_id)
5050

5151

backend/oasst_backend/api/v1/messages.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def get_tree(
8282
"""
8383
pr = PromptRepository(db, api_client)
8484
message = pr.fetch_message(message_id)
85-
tree = pr.fetch_message_tree(message.message_tree_id)
85+
tree = pr.fetch_message_tree(message.message_tree_id, reviewed=False)
8686
return utils.prepare_tree(tree, message.message_tree_id)
8787

8888

backend/oasst_backend/prompt_repository.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from oasst_shared.exceptions import OasstError, OasstErrorCode
2929
from oasst_shared.schemas import protocol as protocol_schema
3030
from oasst_shared.schemas.protocol import SystemStats
31-
from sqlmodel import Session, func, text, update
31+
from sqlmodel import Session, func, not_, text, update
3232
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
3333

3434

@@ -506,10 +506,14 @@ def fetch_random_initial_prompts(self, size: int = 5):
506506
messages = self.db.query(Message).filter(Message.parent_id.is_(None)).order_by(func.random()).limit(size).all()
507507
return messages
508508

509-
def fetch_message_tree(self, message_tree_id: UUID, reviewed: bool = True) -> list[Message]:
509+
def fetch_message_tree(
510+
self, message_tree_id: UUID, reviewed: bool = True, include_deleted: bool = False
511+
) -> list[Message]:
510512
qry = self.db.query(Message).filter(Message.message_tree_id == message_tree_id)
511513
if reviewed:
512514
qry = qry.filter(Message.review_result)
515+
if not include_deleted:
516+
qry = qry.filter(not_(Message.deleted))
513517
return qry.all()
514518

515519
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):

backend/oasst_backend/tree_manager.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,7 @@ def purge_user_messages(
10751075
bad_parent_ids = set(m.id for m in replies)
10761076
logger.debug(f"patching tree {tree_id=}, {bad_parent_ids=}")
10771077

1078-
tree_messages = self.pr.fetch_message_tree(tree_id)
1078+
tree_messages = self.pr.fetch_message_tree(tree_id, reviewed=False, include_deleted=True)
10791079
logger.debug(f"{tree_id=}, {len(bad_parent_ids)=}, {len(tree_messages)=}")
10801080
by_id = {m.id: m for m in tree_messages}
10811081

0 commit comments

Comments
 (0)