1
+ import json
1
2
import random
3
+ import sys
2
4
from datetime import datetime
3
5
from enum import Enum
4
6
from http import HTTPStatus
7
9
8
10
import numpy as np
9
11
import pydantic
12
+ from fastapi .encoders import jsonable_encoder
10
13
from loguru import logger
11
14
from oasst_backend .api .v1 .utils import prepare_conversation , prepare_conversation_message_list
12
15
from oasst_backend .config import TreeManagerConfiguration , settings
13
16
from oasst_backend .models import Message , MessageReaction , MessageTreeState , Task , TextLabels , User , message_tree_state
14
17
from oasst_backend .prompt_repository import PromptRepository
18
+ from oasst_backend .utils import tree_export
15
19
from oasst_backend .utils .database_utils import CommitMode , async_managed_tx_method , managed_tx_method
16
20
from oasst_backend .utils .hugging_face import HfClassificationModel , HfEmbeddingModel , HfUrl , HuggingFaceAPI
17
21
from oasst_backend .utils .ranking import ranked_pairs
@@ -1184,14 +1188,55 @@ def purge_user(self, user_id: UUID, ban: bool = True) -> None:
1184
1188
if ban :
1185
1189
self .db .execute (update (User ).filter (User .id == user_id ).values (deleted = True , enabled = False ))
1186
1190
1191
+ def export_trees_to_file (
1192
+ self ,
1193
+ message_tree_ids : list [str ],
1194
+ file = None ,
1195
+ reviewed : bool = True ,
1196
+ include_deleted : bool = False ,
1197
+ use_compression : bool = False ,
1198
+ ) -> None :
1199
+ trees_to_export : List [tree_export .ExportMessageTree ] = []
1200
+
1201
+ for message_tree_id in message_tree_ids :
1202
+ messages : List [Message ] = self .pr .fetch_message_tree (message_tree_id , reviewed , include_deleted )
1203
+ trees_to_export .append (tree_export .build_export_tree (message_tree_id , messages ))
1204
+
1205
+ if file :
1206
+ tree_export .write_trees_to_file (file , trees_to_export , use_compression )
1207
+ else :
1208
+ sys .stdout .write (json .dumps (jsonable_encoder (trees_to_export ), indent = 2 ))
1209
+
1210
+ def export_all_ready_trees (
1211
+ self , file : str , reviewed : bool = True , include_deleted : bool = False , use_compression : bool = False
1212
+ ) -> None :
1213
+ message_tree_states : MessageTreeState = self .pr .fetch_message_trees_ready_for_export ()
1214
+ message_tree_ids = [ms .message_tree_id for ms in message_tree_states ]
1215
+ self .export_trees_to_file (message_tree_ids , file , reviewed , include_deleted , use_compression )
1216
+
1217
+ def export_all_user_trees (
1218
+ self ,
1219
+ user_id : str ,
1220
+ file : str ,
1221
+ reviewed : bool = True ,
1222
+ include_deleted : bool = False ,
1223
+ use_compression : bool = False ,
1224
+ ) -> None :
1225
+ messages = self .pr .fetch_user_message_trees (UUID (user_id ))
1226
+ message_tree_ids = [ms .message_tree_id for ms in messages ]
1227
+ self .export_trees_to_file (message_tree_ids , file , reviewed , include_deleted , use_compression )
1228
+
1187
1229
1188
1230
if __name__ == "__main__" :
1189
1231
from oasst_backend .api .deps import api_auth
1232
+
1233
+ # from oasst_backend.api.deps import create_api_client
1190
1234
from oasst_backend .database import engine
1191
1235
from oasst_backend .prompt_repository import PromptRepository
1192
1236
1193
1237
with Session (engine ) as db :
1194
1238
api_client = api_auth (settings .OFFICIAL_WEB_API_KEY , db = db )
1239
+ # api_client = create_api_client(session=db, description="test", frontend_type="bot")
1195
1240
dummy_user = protocol_schema .User (id = "__dummy_user__" , display_name = "Dummy User" , auth_method = "local" )
1196
1241
1197
1242
pr = PromptRepository (db = db , api_client = api_client , client_user = dummy_user )
@@ -1200,25 +1245,22 @@ def purge_user(self, user_id: UUID, ban: bool = True) -> None:
1200
1245
tm = TreeManager (db , pr , cfg )
1201
1246
tm .ensure_tree_states ()
1202
1247
1203
- tm .purge_user_messages (user_id = UUID ("2ef9ad21-0dc5-442d-8750-6f7f1790723f" ), purge_initial_prompts = False )
1248
+ # tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False)
1204
1249
# tm.purge_user(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"))
1205
1250
# db.commit()
1206
1251
1207
1252
# print("query_num_active_trees", tm.query_num_active_trees())
1208
1253
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
1209
1254
# print("query_replies_need_review", tm.query_replies_need_review())
1255
+ # print("query_incomplete_reply_reviews", tm.query_replies_need_review())
1210
1256
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
1211
1257
# print("query_extendible_trees", tm.query_extendible_trees())
1212
1258
# print("query_extendible_parents", tm.query_extendible_parents())
1213
- # print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292")))
1214
-
1215
- # print(
1216
- # "query_reviews_for_message",
1217
- # tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")),
1218
- # )
1219
1259
1220
1260
# print("next_task:", tm.next_task())
1221
1261
1222
1262
# print(
1223
- # "query_tree_ranking_results", tm.query_tree_ranking_results(UUID("6036f58f-41b5-48c4-bdd9-b16f34ab1312 "))
1263
+ # ". query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921 "))
1224
1264
# )
1265
+
1266
+ print (tm .export_trees_to_file (message_tree_ids = ["7e75fb38-e664-4e2b-817c-b9a0b01b0074" ], file = "lol.jsonl" ))
0 commit comments