
Commit ccbf41e

Author: LittleMouse
[perf] Optimize model loading flow

1 parent 5a5fbf2 · commit ccbf41e

File tree

3 files changed: +58 -29 lines changed

api_server.py

Lines changed: 34 additions & 23 deletions
@@ -57,25 +57,36 @@ async def auth_middleware(request: Request, call_next):
 class ModelDispatcher:
     def __init__(self):
         self.backends = {}
-        self.load_models()
+        self.llm_models = []
+        self.lock = asyncio.Lock()

-    def load_models(self):
-        for model_name, model_config in config.data["models"].items():
-            if model_config["type"] == "openai_proxy":
-                self.backends[model_name] = OpenAIProxyBackend(model_config)
-            elif model_config["type"] == "tcp_client":
-                self.backends[model_name] = LlmClientBackend(model_config)
-            elif model_config["type"] == "llama.cpp":
-                self.backends[model_name] = TestBackend(model_config)
-            elif model_config["type"] == "vision_model":
-                self.backends[model_name] = VisionModelBackend(model_config)
-            elif model_config["type"] == "tts":
-                self.backends[model_name] = TtsClientBackend(model_config)
-            elif model_config["type"] == "asr":
-                self.backends[model_name] = ASRClientBackend(model_config)
-
-    def get_backend(self, model_name):
-        return self.backends.get(model_name)
+    async def get_backend(self, model_name):
+        async with self.lock:
+            if model_name not in self.backends:
+                model_config = config.data["models"].get(model_name)
+                if model_config is None:
+                    return None
+                if model_config["type"] == "openai_proxy":
+                    self.backends[model_name] = OpenAIProxyBackend(model_config)
+                elif model_config["type"] in ("llm", "vlm"):
+                    while len(self.llm_models) >= 2:
+                        oldest_model = self.llm_models.pop(0)
+                        old_instance = self.backends.pop(oldest_model, None)
+                        if old_instance:
+                            await old_instance.close()
+                    self.backends[model_name] = LlmClientBackend(model_config)
+                    self.llm_models.append(model_name)
+                elif model_config["type"] == "llama.cpp":
+                    self.backends[model_name] = TestBackend(model_config)
+                elif model_config["type"] == "vision_model":
+                    self.backends[model_name] = VisionModelBackend(model_config)
+                elif model_config["type"] == "tts":
+                    self.backends[model_name] = TtsClientBackend(model_config)
+                elif model_config["type"] == "asr":
+                    self.backends[model_name] = ASRClientBackend(model_config)
+                else:
+                    return None
+            return self.backends.get(model_name)

 async def initialize():
     global config
@@ -92,7 +103,7 @@ async def initialize():

 @app.post("/v1/chat/completions")
 async def chat_completions(request: Request, body: ChatCompletionRequest):
-    backend = _dispatcher.get_backend(body.model)
+    backend = await _dispatcher.get_backend(body.model)
     if not backend:
         raise HTTPException(
             status_code=400,
@@ -156,7 +167,7 @@ async def create_completion(request: Request, body: CompletionRequest):
         stream=body.stream
     )

-    backend = _dispatcher.get_backend(chat_request.model)
+    backend = await _dispatcher.get_backend(chat_request.model)
     if not backend:
         raise HTTPException(status_code=400, detail=f"Unsupported model: {chat_request.model}")

@@ -215,7 +226,7 @@ async def convert_stream():
 async def create_speech(request: Request):
     try:
         request_data = await request.json()
-        backend = _dispatcher.get_backend(request_data.get("model"))
+        backend = await _dispatcher.get_backend(request_data.get("model"))
         if not backend:
             raise HTTPException(status_code=400, detail="Unsupported model")

@@ -243,7 +254,7 @@ async def create_transcription(
     response_format: str = Form("json")
 ):
     try:
-        backend = _dispatcher.get_backend(model)
+        backend = await _dispatcher.get_backend(model)
         if not backend:
             raise HTTPException(status_code=400, detail="Unsupported model")

@@ -273,7 +284,7 @@ async def create_translation(
     response_format: str = Form("json")
 ):
     try:
-        backend = _dispatcher.get_backend(model)
+        backend = await _dispatcher.get_backend(model)
         if not backend:
             raise HTTPException(status_code=400, detail="Unsupported model")
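The rewritten dispatcher is effectively a lazy, lock-guarded cache: backends are constructed on first request instead of at startup, and at most two llm/vlm backends stay resident, with the oldest-loaded one closed to make room. A minimal standalone sketch of that pattern (DummyBackend and the model names are stand-ins, not the project's real classes):

import asyncio

class DummyBackend:
    """Stand-in for LlmClientBackend; only the async close() hook matters here."""
    def __init__(self, name):
        self.name = name

    async def close(self):
        print(f"closed {self.name}")

class LazyDispatcher:
    MAX_RESIDENT = 2  # mirrors the hard-coded limit of two resident llm/vlm models

    def __init__(self):
        self.backends = {}
        self.llm_models = []        # FIFO record of load order
        self.lock = asyncio.Lock()  # serializes create/evict decisions

    async def get_backend(self, name):
        async with self.lock:
            if name not in self.backends:
                # Evict the oldest-loaded model(s) until there is room.
                while len(self.llm_models) >= self.MAX_RESIDENT:
                    oldest = self.llm_models.pop(0)
                    old = self.backends.pop(oldest, None)
                    if old:
                        await old.close()
                self.backends[name] = DummyBackend(name)
                self.llm_models.append(name)
            return self.backends[name]

async def main():
    d = LazyDispatcher()
    for name in ("model-0.5B", "model-1B", "model-1.5B"):
        await d.get_backend(name)  # third call prints: closed model-0.5B

asyncio.run(main())

Note that llm_models records load order, not last use, so a model that is still serving traffic can be evicted once two newer models have been requested after it.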

backend/llm_client_backend.py

Lines changed: 9 additions & 3 deletions
@@ -72,9 +72,7 @@ async def _get_client(self, request):
         await asyncio.wait_for(self._pool_lock.acquire(), timeout=timeout - (time.time() - start_time))

         if "memory_required" in self.config:
-            await self.memory_checker.check_memory(
-                self.config["memory_required"]
-            )
+            await self.memory_checker.check_memory(self.config["memory_required"])

         self.logger.debug("Creating new LLM client")
         client = LLMClient(
@@ -117,6 +115,14 @@ async def _release_client(self, client):
             self._client_pool.append(client)
             self.logger.debug(f"Returned client to pool | ID:{id(client)}")

+    async def close(self):
+        async with self._pool_lock:
+            for client in self._client_pool:
+                client.exit()
+            self._client_pool.clear()
+            self._active_clients.clear()
+            self._inference_executor.shutdown(wait=True)
+
     async def inference_stream(self, query: str, base64_images: list, request: ChatCompletionRequest):
         client = await self._get_client(request)
         task = asyncio.current_task()

services/model_list.py

Lines changed: 15 additions & 3 deletions
@@ -46,26 +46,38 @@ async def get_model_list(self, required_mem: int) -> None:
                     "object": f"{model_type}.setup",
                     "system_prompt": "You are a helpful assistant."
                 })
+                if '-1.5B-' in mode:
+                    new_entry['memory_required'] = 1782579
+                    new_entry['pool_size'] = 1
+                elif '-1B-' in mode:
+                    new_entry['memory_required'] = 1363148
+                    new_entry['pool_size'] = 2
+                elif '-0.5B-' in mode:
+                    new_entry['memory_required'] = 560460
+                    new_entry['pool_size'] = 2
+
             elif model_type == 'tts':
                 if 'melotts' in mode.lower():
                     obj = 'melotts.setup'
+                    new_entry['memory_required'] = 59764
                 else:
                     obj = 'tts.setup'

                 new_entry.update({
                     "response_format": "wav.base64",
-                    "object": "melotts.setup",
                     "object": obj
                 })
             elif model_type == 'asr':
                 if 'whisper' in mode.lower():
                     obj = 'whisper.setup'
+                    if 'tiny' in mode:
+                        new_entry['memory_required'] = 289132
                 else:
                     obj = 'asr.setup'
                 new_entry.update({
                     "input": "pcm.base64",
                     "response_format": "asr.utf-8",
-                    "object": "whisper.setup"
+                    "object": obj
                 })
             else:
                 continue
@@ -84,4 +96,4 @@ async def _get_model_list(self):
         return await loop.run_in_executor(
             None,
             self._sys_client.model_list
-        )
+        )
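The size-tag branches amount to a lookup table from a model-name substring to a (memory_required, pool_size) pair, with the largest model getting the smallest pool. The same mapping written as data instead of an if/elif chain (a refactoring sketch, not the shipped code; the example model name is made up, and the units of memory_required are whatever the memory checker expects):

# (name substring) -> (memory_required, pool_size), values copied from the diff
LLM_SIZE_TABLE = [
    ('-1.5B-', 1782579, 1),
    ('-1B-',   1363148, 2),
    ('-0.5B-',  560460, 2),
]

def apply_llm_sizing(new_entry: dict, mode: str) -> None:
    """Fill memory_required/pool_size from the first matching size tag, if any."""
    for tag, mem, pool in LLM_SIZE_TABLE:
        if tag in mode:
            new_entry['memory_required'] = mem
            new_entry['pool_size'] = pool
            return

entry = {}
apply_llm_sizing(entry, 'qwen2.5-0.5B-ax630c')  # hypothetical model name
print(entry)  # {'memory_required': 560460, 'pool_size': 2}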
