@@ -57,25 +57,36 @@ async def auth_middleware(request: Request, call_next):
5757class ModelDispatcher :
5858 def __init__ (self ):
5959 self .backends = {}
60- self .load_models ()
60+ self .llm_models = []
61+ self .lock = asyncio .Lock ()
6162
62- def load_models (self ):
63- for model_name , model_config in config .data ["models" ].items ():
64- if model_config ["type" ] == "openai_proxy" :
65- self .backends [model_name ] = OpenAIProxyBackend (model_config )
66- elif model_config ["type" ] == "tcp_client" :
67- self .backends [model_name ] = LlmClientBackend (model_config )
68- elif model_config ["type" ] == "llama.cpp" :
69- self .backends [model_name ] = TestBackend (model_config )
70- elif model_config ["type" ] == "vision_model" :
71- self .backends [model_name ] = VisionModelBackend (model_config )
72- elif model_config ["type" ] == "tts" :
73- self .backends [model_name ] = TtsClientBackend (model_config )
74- elif model_config ["type" ] == "asr" :
75- self .backends [model_name ] = ASRClientBackend (model_config )
76-
77- def get_backend (self , model_name ):
78- return self .backends .get (model_name )
63+ async def get_backend (self , model_name ):
64+ async with self .lock :
65+ if model_name not in self .backends :
66+ model_config = config .data ["models" ].get (model_name )
67+ if model_config is None :
68+ return None
69+ if model_config ["type" ] == "openai_proxy" :
70+ self .backends [model_name ] = OpenAIProxyBackend (model_config )
71+ elif model_config ["type" ] in ("llm" , "vlm" ):
72+ while len (self .llm_models ) >= 2 :
73+ oldest_model = self .llm_models .pop (0 )
74+ old_instance = self .backends .pop (oldest_model , None )
75+ if old_instance :
76+ await old_instance .close ()
77+ self .backends [model_name ] = LlmClientBackend (model_config )
78+ self .llm_models .append (model_name )
79+ elif model_config ["type" ] == "llama.cpp" :
80+ self .backends [model_name ] = TestBackend (model_config )
81+ elif model_config ["type" ] == "vision_model" :
82+ self .backends [model_name ] = VisionModelBackend (model_config )
83+ elif model_config ["type" ] == "tts" :
84+ self .backends [model_name ] = TtsClientBackend (model_config )
85+ elif model_config ["type" ] == "asr" :
86+ self .backends [model_name ] = ASRClientBackend (model_config )
87+ else :
88+ return None
89+ return self .backends .get (model_name )
7990
8091async def initialize ():
8192 global config
@@ -92,7 +103,7 @@ async def initialize():
92103
93104@app .post ("/v1/chat/completions" )
94105async def chat_completions (request : Request , body : ChatCompletionRequest ):
95- backend = _dispatcher .get_backend (body .model )
106+ backend = await _dispatcher .get_backend (body .model )
96107 if not backend :
97108 raise HTTPException (
98109 status_code = 400 ,
@@ -156,7 +167,7 @@ async def create_completion(request: Request, body: CompletionRequest):
156167 stream = body .stream
157168 )
158169
159- backend = _dispatcher .get_backend (chat_request .model )
170+ backend = await _dispatcher .get_backend (chat_request .model )
160171 if not backend :
161172 raise HTTPException (status_code = 400 , detail = f"Unsupported model: { chat_request .model } " )
162173
@@ -215,7 +226,7 @@ async def convert_stream():
215226async def create_speech (request : Request ):
216227 try :
217228 request_data = await request .json ()
218- backend = _dispatcher .get_backend (request_data .get ("model" ))
229+ backend = await _dispatcher .get_backend (request_data .get ("model" ))
219230 if not backend :
220231 raise HTTPException (status_code = 400 , detail = "Unsupported model" )
221232
@@ -243,7 +254,7 @@ async def create_transcription(
243254 response_format : str = Form ("json" )
244255):
245256 try :
246- backend = _dispatcher .get_backend (model )
257+ backend = await _dispatcher .get_backend (model )
247258 if not backend :
248259 raise HTTPException (status_code = 400 , detail = "Unsupported model" )
249260
@@ -273,7 +284,7 @@ async def create_translation(
273284 response_format : str = Form ("json" )
274285):
275286 try :
276- backend = _dispatcher .get_backend (model )
287+ backend = await _dispatcher .get_backend (model )
277288 if not backend :
278289 raise HTTPException (status_code = 400 , detail = "Unsupported model" )
279290
0 commit comments