|
| 1 | +import argparse |
| 2 | +import json |
| 3 | +from pathlib import Path |
| 4 | +from typing import Optional |
| 5 | +from uuid import UUID |
| 6 | + |
| 7 | +import oasst_backend.models.db_payload as db_payload |
| 8 | +import oasst_backend.utils.database_utils as db_utils |
| 9 | +import pydantic |
| 10 | +from loguru import logger |
| 11 | +from oasst_backend.api.deps import create_api_client |
| 12 | +from oasst_backend.models import ApiClient, Message |
| 13 | +from oasst_backend.models.message_tree_state import MessageTreeState |
| 14 | +from oasst_backend.models.message_tree_state import State as TreeState |
| 15 | +from oasst_backend.models.payload_column_type import PayloadContainer |
| 16 | +from oasst_backend.prompt_repository import PromptRepository |
| 17 | +from oasst_backend.user_repository import UserRepository |
| 18 | +from oasst_backend.utils.tree_export import ExportMessageNode, ExportMessageTree |
| 19 | +from sqlmodel import Session |
| 20 | + |
| 21 | +# well known id |
| 22 | +IMPORT_API_CLIENT_ID = UUID("bd8fde8b-1d8e-4e9a-9966-e96d000f8363") |
| 23 | + |
| 24 | + |
| 25 | +class Importer: |
| 26 | + def __init__(self, db: Session, origin: str, model_name: Optional[str] = None): |
| 27 | + self.db = db |
| 28 | + self.origin = origin |
| 29 | + self.model_name = model_name |
| 30 | + |
| 31 | + # get import api client |
| 32 | + api_client = db.query(ApiClient).filter(ApiClient.id == IMPORT_API_CLIENT_ID).first() |
| 33 | + if not api_client: |
| 34 | + api_client = create_api_client( |
| 35 | + session=db, |
| 36 | + description="API client used for importing data", |
| 37 | + frontend_type="import", |
| 38 | + force_id=IMPORT_API_CLIENT_ID, |
| 39 | + ) |
| 40 | + |
| 41 | + ur = UserRepository(db, api_client) |
| 42 | + self.import_user = ur.lookup_system_user(username="import") |
| 43 | + self.pr = PromptRepository(db=db, api_client=api_client, user_repository=ur) |
| 44 | + self.api_client = api_client |
| 45 | + |
| 46 | + def fetch_message_tree_state(self, message_tree_id: UUID) -> MessageTreeState: |
| 47 | + return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one_or_none() |
| 48 | + |
| 49 | + def import_message( |
| 50 | + self, message: ExportMessageNode, message_tree_id: UUID, parent_id: Optional[UUID] = None |
| 51 | + ) -> Message: |
| 52 | + payload = db_payload.MessagePayload(text=message.text) |
| 53 | + msg = Message( |
| 54 | + id=message.message_id, |
| 55 | + message_tree_id=message_tree_id, |
| 56 | + frontend_message_id=message.message_id, |
| 57 | + parent_id=parent_id, |
| 58 | + review_count=message.review_count or 0, |
| 59 | + lang=message.lang or "en", |
| 60 | + review_result=True, |
| 61 | + synthetic=message.synthetic if message.synthetic is not None else True, |
| 62 | + model_name=message.model_name or self.model_name, |
| 63 | + role=message.role, |
| 64 | + api_client_id=self.api_client.id, |
| 65 | + payload_type=type(payload).__name__, |
| 66 | + payload=PayloadContainer(payload=payload), |
| 67 | + user_id=self.import_user.id, |
| 68 | + ) |
| 69 | + self.db.add(msg) |
| 70 | + if message.replies: |
| 71 | + for r in message.replies: |
| 72 | + self.import_message(r, message_tree_id=message_tree_id, parent_id=msg.id) |
| 73 | + self.db.flush() |
| 74 | + if parent_id is None: |
| 75 | + self.pr.update_children_counts(msg.id) |
| 76 | + self.db.refresh(msg) |
| 77 | + return msg |
| 78 | + |
| 79 | + def import_tree( |
| 80 | + self, tree: ExportMessageTree, state: TreeState = TreeState.BACKLOG_RANKING |
| 81 | + ) -> tuple[MessageTreeState, Message]: |
| 82 | + assert tree.message_tree_id is not None and tree.message_tree_id == tree.prompt.message_id |
| 83 | + root_msg = self.import_message(tree.prompt, message_tree_id=tree.prompt.message_id) |
| 84 | + assert state == TreeState.BACKLOG_RANKING or state == TreeState.RANKING, f"{state} not supported for import" |
| 85 | + active = state == TreeState.RANKING |
| 86 | + mts = MessageTreeState( |
| 87 | + message_tree_id=root_msg.id, |
| 88 | + goal_tree_size=0, |
| 89 | + max_depth=0, |
| 90 | + max_children_count=0, |
| 91 | + state=state, |
| 92 | + origin=self.origin, |
| 93 | + active=active, |
| 94 | + ) |
| 95 | + self.db.add(mts) |
| 96 | + return mts, root_msg |
| 97 | + |
| 98 | + |
| 99 | +def import_file( |
| 100 | + input_file_path: Path, |
| 101 | + origin: str, |
| 102 | + *, |
| 103 | + model_name: Optional[str] = None, |
| 104 | + num_activate: int = 0, |
| 105 | + max_count: Optional[int] = None, |
| 106 | + dry_run: bool = False, |
| 107 | +) -> int: |
| 108 | + @db_utils.managed_tx_function(auto_commit=db_utils.CommitMode.ROLLBACK if dry_run else db_utils.CommitMode.COMMIT) |
| 109 | + def import_tx(db: Session) -> int: |
| 110 | + importer = Importer(db, origin=origin, model_name=model_name) |
| 111 | + i = 0 |
| 112 | + with input_file_path.open() as file_in: |
| 113 | + # read line tree object |
| 114 | + for line in file_in: |
| 115 | + dict_tree = json.loads(line) |
| 116 | + |
| 117 | + # validate data |
| 118 | + tree: ExportMessageTree = pydantic.parse_obj_as(ExportMessageTree, dict_tree) |
| 119 | + existing_mts = importer.fetch_message_tree_state(tree.message_tree_id) |
| 120 | + if existing_mts: |
| 121 | + logger.info(f"Skipping existing message tree: {tree.message_tree_id}") |
| 122 | + else: |
| 123 | + state = TreeState.BACKLOG_RANKING if i >= num_activate else TreeState.RANKING |
| 124 | + mts, root_msg = importer.import_tree(tree, state=state) |
| 125 | + i += 1 |
| 126 | + logger.info( |
| 127 | + f"imported tree: {mts.message_tree_id}, {mts.state=}, {mts.active=}, {root_msg.children_count=}" |
| 128 | + ) |
| 129 | + |
| 130 | + if max_count and i >= max_count: |
| 131 | + logger.info(f"Reached max count {max_count} of trees to import.") |
| 132 | + break |
| 133 | + return i |
| 134 | + |
| 135 | + if dry_run: |
| 136 | + logger.info("DRY RUN with rollback") |
| 137 | + return import_tx() |
| 138 | + |
| 139 | + |
| 140 | +def parse_args(): |
| 141 | + def str2bool(v): |
| 142 | + if isinstance(v, bool): |
| 143 | + return v |
| 144 | + if v.lower() in ("yes", "true", "t", "y", "1"): |
| 145 | + return True |
| 146 | + elif v.lower() in ("no", "false", "f", "n", "0"): |
| 147 | + return False |
| 148 | + else: |
| 149 | + raise argparse.ArgumentTypeError("Boolean value expected.") |
| 150 | + |
| 151 | + parser = argparse.ArgumentParser() |
| 152 | + parser.add_argument( |
| 153 | + "input_file_path", |
| 154 | + help="Input file path", |
| 155 | + ) |
| 156 | + parser.add_argument("--origin", type=str, default=None, help="Value for origin of message trees") |
| 157 | + parser.add_argument("--model_name", type=str, default=None, help="Default name of model (if missing in messages)") |
| 158 | + parser.add_argument("--num_activate", type=int, default=0, help="Number of trees to add in ranking state") |
| 159 | + parser.add_argument("--max_count", type=int, default=None, help="Maximum number of message trees to import") |
| 160 | + parser.add_argument("--dry_run", type=str2bool, default=False) |
| 161 | + args = parser.parse_args() |
| 162 | + return args |
| 163 | + |
| 164 | + |
| 165 | +def main(): |
| 166 | + args = parse_args() |
| 167 | + |
| 168 | + input_file_path = Path(args.input_file_path) |
| 169 | + if not input_file_path.exists() or not input_file_path.is_file(): |
| 170 | + print("Invalid input file:", args.input_file_path) |
| 171 | + exit(1) |
| 172 | + |
| 173 | + dry_run = args.dry_run |
| 174 | + num_imported = import_file( |
| 175 | + input_file_path, |
| 176 | + origin=args.origin or input_file_path.name, |
| 177 | + model_name=args.model_name, |
| 178 | + num_activate=args.num_activate, |
| 179 | + max_count=args.max_count, |
| 180 | + dry_run=dry_run, |
| 181 | + ) |
| 182 | + |
| 183 | + logger.info(f"Done ({num_imported=}, {dry_run=})") |
| 184 | + |
| 185 | + |
| 186 | +if __name__ == "__main__": |
| 187 | + main() |
0 commit comments