import contextlib

import redis.asyncio as redis
from fastapi import Depends
from oasst_inference_server import auth
from oasst_inference_server.chat_repository import ChatRepository
from oasst_inference_server.database import AsyncSession, get_async_session
from oasst_inference_server.settings import settings
from oasst_inference_server.user_chat_repository import UserChatRepository


def make_redis_client():
    """Build an asyncio Redis client from the application settings.

    Host, port and database index are read from ``settings``.
    ``decode_responses=True`` is passed to the client constructor (in
    redis-py this makes replies come back as ``str`` rather than ``bytes``).

    Returns:
        A new ``redis.asyncio.Redis`` instance.
    """
    connection_options = {
        "host": settings.redis_host,
        "port": settings.redis_port,
        "db": settings.redis_db,
        "decode_responses": True,
    }
    return redis.Redis(**connection_options)


# Module-level client instance, constructed once when this module is imported.
redis_client = make_redis_client()


async def create_session():
    """Yield each database session produced by ``get_async_session``.

    This async generator is referenced via ``Depends(create_session)`` by the
    repository factories in this module.
    """
    async for db_session in get_async_session():
        yield db_session


@contextlib.asynccontextmanager
async def manual_create_session(autoflush=True):
    """Open a database session as an async context manager.

    Wraps the ``get_async_session`` generator with
    ``contextlib.asynccontextmanager`` so it can be used in an ``async with``
    statement.

    Args:
        autoflush: Forwarded unchanged to ``get_async_session``.

    Yields:
        The session yielded by ``get_async_session``.
    """
    session_scope = contextlib.asynccontextmanager(get_async_session)
    async with session_scope(autoflush=autoflush) as db_session:
        yield db_session


async def create_chat_repository(session: AsyncSession = Depends(create_session)) -> ChatRepository:
    """Return a ``ChatRepository`` bound to the given session.

    Args:
        session: Database session; defaults to the ``create_session``
            dependency when resolved by FastAPI.
    """
    return ChatRepository(session=session)


async def create_user_chat_repository(
    session: AsyncSession = Depends(create_session),
    user_id: str = Depends(auth.get_current_user_id),
) -> UserChatRepository:
    """Return a ``UserChatRepository`` bound to a session and a user id.

    Args:
        session: Database session; defaults to the ``create_session``
            dependency when resolved by FastAPI.
        user_id: Id of the user the repository is created for; defaults to
            the ``auth.get_current_user_id`` dependency.
    """
    return UserChatRepository(session=session, user_id=user_id)


@contextlib.asynccontextmanager
async def manual_chat_repository():
    """Provide a ``ChatRepository`` as an async context manager.

    A session is opened with ``manual_create_session`` and stays open while
    the caller's ``async with`` body runs; the session context is exited
    afterwards.

    Yields:
        The repository returned by ``create_chat_repository`` for that session.
    """
    async with manual_create_session() as db_session:
        chat_repository = await create_chat_repository(db_session)
        yield chat_repository


@contextlib.asynccontextmanager
async def manual_user_chat_repository(user_id: str):
    """Provide a ``UserChatRepository`` as an async context manager.

    A session is opened with ``manual_create_session`` and stays open while
    the caller's ``async with`` body runs; the session context is exited
    afterwards.

    Args:
        user_id: Passed through to ``create_user_chat_repository``.

    Yields:
        The repository returned by ``create_user_chat_repository`` for that
        session and user id.
    """
    async with manual_create_session() as db_session:
        user_chat_repository = await create_user_chat_repository(db_session, user_id)
        yield user_chat_repository