from pathlib import Path
from typing import Any, Dict, List, Optional

from oasst_shared.schemas.protocol import TextLabel
from pydantic import AnyHttpUrl, BaseModel, BaseSettings, FilePath, PostgresDsn, validator


class TreeManagerConfiguration(BaseModel):
    """TreeManager configuration settings"""

    max_active_trees: int = 10
    """Maximum number of concurrently active message trees in the database.
    No new initial prompt tasks are handed out to users if this
    number is reached."""

    max_initial_prompt_review: int = 100
    """Maximum number of initial prompts under review before no more initial prompt tasks will be handed out."""

    max_tree_depth: int = 3
    """Maximum depth of message tree."""

    max_children_count: int = 3
    """Maximum number of reply messages per tree node."""

    num_prompter_replies: int = 1
    """Number of prompter replies to collect per assistant reply."""

    goal_tree_size: int = 12
    """Total number of messages to gather per tree."""

    random_goal_tree_size: bool = False
    """If set to true goal tree sizes will be generated randomly within range [min_goal_tree_size, goal_tree_size]."""

    min_goal_tree_size: int = 5
    """Minimum tree size for random goal sizes."""

    num_reviews_initial_prompt: int = 3
    """Number of peer review checks to collect in INITIAL_PROMPT_REVIEW state."""

    num_reviews_reply: int = 3
    """Number of peer review checks to collect per reply (other than initial_prompt)."""

    auto_mod_enabled: bool = True
    """Flag to enable/disable auto moderation."""

    auto_mod_max_skip_reply: int = 25
    """Automatically set tree state to `halted_by_moderator` when more than the specified number
    of users skip replying to a message. (auto moderation)"""

    auto_mod_red_flags: int = 4
    """Delete messages that receive more than this number of red flags if it is a reply or
    set the tree to `aborted_low_grade` when a prompt is flagged. (auto moderation)"""

    p_full_labeling_review_prompt: float = 1.0
    """Probability of full text-labeling (instead of mandatory only) for initial prompts."""

    p_full_labeling_review_reply_assistant: float = 1.0
    """Probability of full text-labeling (instead of mandatory only) for assistant replies."""

    p_full_labeling_review_reply_prompter: float = 0.25
    """Probability of full text-labeling (instead of mandatory only) for prompter replies."""

    acceptance_threshold_initial_prompt: float = 0.6
    """Threshold for accepting an initial prompt."""

    acceptance_threshold_reply: float = 0.6
    """Threshold for accepting a reply."""

    num_required_rankings: int = 3
    """Number of rankings in which the message participated."""

    p_activate_backlog_tree: float = 0.1
    """Probability to activate a message tree in BACKLOG_RANKING state when another tree enters
    a terminal state."""

    min_active_rankings_per_lang: int = 0
    """When the number of active ranking tasks is below this value when a tree enters a terminal
    state an available tree in BACKLOG_RANKING will be activated (i.e. enters the RANKING state)."""

    # Text labels configured for initial prompts.
    # NOTE(review): presumably the label set used for full text-labeling tasks - confirm against usage.
    labels_initial_prompt: list[TextLabel] = [
        TextLabel.spam,
        TextLabel.lang_mismatch,
        TextLabel.quality,
        TextLabel.creativity,
        TextLabel.humor,
        TextLabel.toxicity,
        TextLabel.violence,
        TextLabel.not_appropriate,
        TextLabel.pii,
        TextLabel.hate_speech,
        TextLabel.sexual_content,
    ]

    # Text labels configured for assistant replies (adds `fails_task` and `helpfulness`).
    labels_assistant_reply: list[TextLabel] = [
        TextLabel.spam,
        TextLabel.lang_mismatch,
        TextLabel.fails_task,
        TextLabel.quality,
        TextLabel.helpfulness,
        TextLabel.creativity,
        TextLabel.humor,
        TextLabel.toxicity,
        TextLabel.violence,
        TextLabel.not_appropriate,
        TextLabel.pii,
        TextLabel.hate_speech,
        TextLabel.sexual_content,
    ]

    # Text labels configured for prompter replies.
    labels_prompter_reply: list[TextLabel] = [
        TextLabel.spam,
        TextLabel.lang_mismatch,
        TextLabel.quality,
        TextLabel.creativity,
        TextLabel.humor,
        TextLabel.toxicity,
        TextLabel.violence,
        TextLabel.not_appropriate,
        TextLabel.pii,
        TextLabel.hate_speech,
        TextLabel.sexual_content,
    ]

    mandatory_labels_initial_prompt: Optional[list[TextLabel]] = [TextLabel.spam]
    """Mandatory labels in text-labeling tasks for initial prompts."""

    mandatory_labels_assistant_reply: Optional[list[TextLabel]] = [TextLabel.spam]
    """Mandatory labels in text-labeling tasks for assistant replies."""

    mandatory_labels_prompter_reply: Optional[list[TextLabel]] = [TextLabel.spam]
    """Mandatory labels in text-labeling tasks for prompter replies."""

    # NOTE(review): presumably toggles whether prompter replies take part in ranking - confirm against usage.
    rank_prompter_replies: bool = False

    lonely_children_count: int = 2
    """Number of children below which parents are preferred during sampling for reply tasks."""

    p_lonely_child_extension: float = 0.75
    """Probability to select a prompter message parent with less than lonely_children_count children."""

    recent_tasks_span_sec: int = 5 * 60  # 5 min
    """Time in seconds of recent tasks to consider for exclusion during task selection."""

    max_pending_tasks_per_user: int = 8
    """Maximum number of pending tasks (neither canceled nor completed) by a single user within
    the time span defined by `recent_tasks_span_sec`."""

    max_prompt_lottery_waiting: int = 250
    """Maximum number of prompts in prompt_lottery_waiting state per language. If this value
    is exceeded no new initial prompt tasks for that language are generated."""

    # Comma separated list of languages, exposed as a list via `init_prompt_disabled_langs_list`.
    # NOTE(review): presumably the languages for which initial prompt tasks are disabled - confirm.
    init_prompt_disabled_langs: str = ""

    @property
    def init_prompt_disabled_langs_list(self) -> list[str]:
        """Return the entries of `init_prompt_disabled_langs` as a list.

        The string is split on commas; surrounding whitespace is stripped and empty
        entries are dropped, so the default empty string yields `[]` rather than `[""]`.
        """
        entries = (entry.strip() for entry in self.init_prompt_disabled_langs.split(","))
        return [entry for entry in entries if entry]


class Settings(BaseSettings):
    """Backend settings.

    Values are loaded by pydantic's `BaseSettings` (environment variables and the
    `.env` file configured in `Config`), falling back to the defaults declared here.
    """

    PROJECT_NAME: str = "open-assistant backend"
    API_V1_STR: str = "/api/v1"
    # NOTE(review): hard-coded default key/secret values below - presumably intended for
    # local development only; make sure they are overridden in any real deployment.
    OFFICIAL_WEB_API_KEY: str = "1234"

    # Encryption fields for handling the web generated JSON Web Tokens.
    # These fields need to be shared with the web's auth settings in order to
    # correctly decrypt the web tokens.
    AUTH_INFO: bytes = b"NextAuth.js Generated Encryption Key"
    AUTH_SALT: bytes = b""
    AUTH_LENGTH: int = 32
    AUTH_SECRET: bytes = b"O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98="
    AUTH_COOKIE_NAME: str = "next-auth.session-token"
    AUTH_ALGORITHM: str = "HS256"
    AUTH_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30

    AUTH_DISCORD_CLIENT_ID: str = ""
    AUTH_DISCORD_CLIENT_SECRET: str = ""

    POSTGRES_HOST: str = "localhost"
    POSTGRES_PORT: str = "5432"
    POSTGRES_USER: str = "postgres"
    POSTGRES_PASSWORD: str = "postgres"
    POSTGRES_DB: str = "postgres"
    DATABASE_URI: Optional[PostgresDsn] = None
    DATABASE_MAX_TX_RETRY_COUNT: int = 3

    # Explicitly annotated for consistency with the other fields (the type was previously inferred).
    DATABASE_POOL_SIZE: int = 75
    DATABASE_MAX_OVERFLOW: int = 20

    RATE_LIMIT: bool = True
    MESSAGE_SIZE_LIMIT: int = 2000
    REDIS_HOST: str = "localhost"
    REDIS_PORT: str = "6379"

    DEBUG_USE_SEED_DATA: bool = False
    DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
        Path(__file__).parent.parent / "test_data/realistic/realistic_seed_data.json"
    )
    DEBUG_ALLOW_SELF_LABELING: bool = False  # allow users to label their own messages
    DEBUG_ALLOW_SELF_RANKING: bool = False  # allow users to rank their own messages
    DEBUG_ALLOW_DUPLICATE_TASKS: bool = False  # offer users tasks to which they already responded
    DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
    DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
    DEBUG_DATABASE_ECHO: bool = False
    DEBUG_IGNORE_TOS_ACCEPTANCE: bool = (  # ignore whether users accepted the ToS
        True  # TODO: set False after ToS acceptance UI was added to web-frontend
    )

    DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120

    HUGGING_FACE_API_KEY: str = ""

    ROOT_TOKENS: List[str] = ["1234"]  # supply a string that can be parsed to a json list

    ENABLE_PROM_METRICS: bool = True  # enable prometheus metrics at /metrics

    @validator("DATABASE_URI", pre=True)
    def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
        """Return `DATABASE_URI` unchanged if given as a string, else build it from the `POSTGRES_*` settings."""
        if isinstance(v, str):
            return v
        return PostgresDsn.build(
            scheme="postgresql",
            user=values.get("POSTGRES_USER"),
            password=values.get("POSTGRES_PASSWORD"),
            host=values.get("POSTGRES_HOST"),
            port=values.get("POSTGRES_PORT"),
            path=f"/{values.get('POSTGRES_DB') or ''}",
        )

    BACKEND_CORS_ORIGINS_CSV: Optional[str] = None  # allow setting CORS origins as comma separated values
    BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = []

    @validator("BACKEND_CORS_ORIGINS", pre=True)
    def assemble_cors_origins(cls, v: Optional[List[str]], values: Dict[str, Any]) -> List[str]:
        """Derive the CORS origins from `BACKEND_CORS_ORIGINS_CSV` when that setting is a string.

        When the CSV setting is present it takes precedence over `v`. Entries are stripped
        and empty entries (empty string, trailing comma) are dropped so that they do not
        fail URL validation. Otherwise `v` is returned unchanged.
        """
        origins_csv = values.get("BACKEND_CORS_ORIGINS_CSV")
        if not isinstance(origins_csv, str):
            return v
        origins = (origin.strip() for origin in origins_csv.split(","))
        return [origin for origin in origins if origin]

    UPDATE_ALEMBIC: bool = True

    tree_manager: Optional[TreeManagerConfiguration] = TreeManagerConfiguration()

    USER_STATS_INTERVAL_DAY: int = 5  # minutes
    USER_STATS_INTERVAL_WEEK: int = 15  # minutes
    USER_STATS_INTERVAL_MONTH: int = 60  # minutes
    USER_STATS_INTERVAL_TOTAL: int = 240  # minutes
    USER_STREAK_UPDATE_INTERVAL: int = 4  # Hours

    @validator(
        "USER_STATS_INTERVAL_DAY",
        "USER_STATS_INTERVAL_WEEK",
        "USER_STATS_INTERVAL_MONTH",
        "USER_STATS_INTERVAL_TOTAL",
        "USER_STREAK_UPDATE_INTERVAL",
    )
    def validate_user_stats_intervals(cls, v: int) -> int:
        """Reject update intervals smaller than 1.

        Raises:
            ValueError: if `v` is less than 1.
        """
        if v < 1:
            raise ValueError(f"interval must be >= 1, got {v}")
        return v

    CACHED_STATS_UPDATE_INTERVAL: int = 60  # minutes

    RATE_LIMIT_TASK_USER_TIMES: int = 30
    RATE_LIMIT_TASK_USER_MINUTES: int = 4
    RATE_LIMIT_TASK_API_TIMES: int = 10_000
    RATE_LIMIT_TASK_API_MINUTES: int = 1

    RATE_LIMIT_ASSISTANT_USER_TIMES: int = 4
    RATE_LIMIT_ASSISTANT_USER_MINUTES: int = 2

    RATE_LIMIT_PROMPTER_USER_TIMES: int = 8
    RATE_LIMIT_PROMPTER_USER_MINUTES: int = 2

    TASK_VALIDITY_MINUTES: int = 60 * 24 * 2  # tasks expire after 2 days

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
        case_sensitive = False
        env_nested_delimiter = "__"


settings = Settings()