Skip to content

Commit 46e0ecf

Browse files
author
LittleMouse
committed
[refactor] Refactor asr_client_backend
1 parent b7fa5ad commit 46e0ecf

File tree

4 files changed

+99
-79
lines changed

4 files changed

+99
-79
lines changed

api_server.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ async def chat_completions(request: Request, body: ChatCompletionRequest):
106106
backend = await _dispatcher.get_backend(body.model)
107107
if not backend:
108108
raise HTTPException(
109-
status_code=400,
109+
status_code=400,
110110
detail=f"Unsupported model: {body.model}"
111111
)
112112

@@ -250,19 +250,20 @@ async def create_transcription(
250250
prompt: str = Form(""),
251251
response_format: str = Form("json")
252252
):
253-
try:
254-
backend = await _dispatcher.get_backend(model)
255-
if not backend:
256-
raise HTTPException(status_code=400, detail="Unsupported model")
253+
backend = await _dispatcher.get_backend(model)
254+
if not backend:
255+
raise HTTPException(
256+
status_code=400,
257+
detail=f"Unsupported model: {model}"
258+
)
257259

260+
try:
258261
audio_data = await file.read()
259-
260262
transcription = await backend.create_transcription(
261263
audio_data,
262264
language=language,
263265
prompt=prompt
264266
)
265-
266267
return JSONResponse(content={
267268
"text": transcription,
268269
"task": "transcribe",

backend/asr_client_backend.py

Lines changed: 88 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,106 @@
1-
from .base_model_backend import BaseModelBackend
2-
from client.asr_client import ASRClient
1+
import time
32
import asyncio
3+
import weakref
44
import base64
55
import logging
6+
from .base_model_backend import BaseModelBackend
7+
from client.asr_client import ASRClient
68
from concurrent.futures import ThreadPoolExecutor
7-
8-
logger = logging.getLogger("api.asr")
9+
from services.memory_check import MemoryChecker
910

1011
class ASRClientBackend(BaseModelBackend):
    """Model backend that transcribes audio through pooled ``ASRClient`` connections.

    Clients are created lazily (at most ``POOL_SIZE`` of them), configured once
    with a ``whisper.setup`` request, and then recycled through an internal
    pool. Blocking client calls run in worker threads so the asyncio event loop
    stays responsive while a transcription is in flight.
    """

    def __init__(self, model_config):
        """Initialise pool bookkeeping, the inference executor and the memory checker.

        Args:
            model_config: Backend configuration, forwarded to
                ``BaseModelBackend.__init__``. This class reads ``host``,
                ``port`` and ``model_name`` from ``self.config``
                (presumably populated by the base class -- verify there).
        """
        super().__init__(model_config)
        self._client_pool = []        # idle clients that are ready for reuse
        self._active_clients = {}     # id(client) -> every client created so far
        self._pool_lock = asyncio.Lock()
        self.logger = logging.getLogger("api.asr")
        self.POOL_SIZE = 1
        self._inference_executor = ThreadPoolExecutor(max_workers=self.POOL_SIZE)
        # Weak references: a finished task must not be kept alive by this set.
        self._active_tasks = weakref.WeakSet()
        self.memory_checker = MemoryChecker(
            host=self.config["host"],
            port=self.config["port"],
        )

    async def _get_client(self):
        """Return a ready-to-use client, creating one if the pool has capacity.

        When every client is busy, the running transcription tasks are
        cancelled so that one of them hands its client back, and the pool is
        polled again every few seconds until the overall timeout expires.

        Returns:
            A client on which ``whisper.setup`` has already been performed.

        Raises:
            RuntimeError: If no client became available within the timeout.
        """
        acquire_timeout = 30.0
        retry_interval = 3
        # Track ownership locally. Checking ``self._pool_lock.locked()`` in the
        # ``finally`` would also be true when *another* task holds the lock,
        # and releasing it there would break mutual exclusion.
        lock_held = False
        try:
            await asyncio.wait_for(self._pool_lock.acquire(), timeout=acquire_timeout)
            lock_held = True
            deadline = time.monotonic() + acquire_timeout

            while True:
                if self._client_pool:
                    return self._client_pool.pop()

                # Capacity left: leave the wait loop and create a new client.
                if len(self._active_clients) < self.POOL_SIZE:
                    break

                # Pool exhausted: pre-empt running transcriptions so their
                # clients are released back into the pool.
                for task in list(self._active_tasks):
                    task.cancel()

                # Drop the lock while sleeping so the cancelled tasks can
                # actually return their clients.
                self._pool_lock.release()
                lock_held = False
                await asyncio.sleep(retry_interval)
                await asyncio.wait_for(
                    self._pool_lock.acquire(),
                    timeout=deadline - time.monotonic(),
                )
                lock_held = True

            # if "memory_required" in self.config:
            #     await self.memory_checker.check_memory(self.config["memory_required"])
            client = ASRClient(
                host=self.config["host"],
                port=self.config["port"],
            )

            loop = asyncio.get_event_loop()
            await loop.run_in_executor(
                None,
                client.setup,
                "whisper.setup",
                {
                    "model": self.config["model_name"],
                    "response_format": "asr.utf-8",
                    "input": "whisper.base64",
                    "language": "zh",
                    "enoutput": True,
                },
            )
            # Register only after setup succeeded, so a failed setup does not
            # permanently consume a pool slot.
            self._active_clients[id(client)] = client
            return client
        except asyncio.TimeoutError as exc:
            raise RuntimeError("Server busy, please try again later.") from exc
        finally:
            if lock_held:
                self._pool_lock.release()

    async def _release_client(self, client):
        """Put ``client`` back into the idle pool."""
        async with self._pool_lock:
            self._client_pool.append(client)

    async def _inference(self, client, audio_b64: str) -> str:
        """Run one inference request on ``client`` and return its final chunk.

        Args:
            client: A client obtained from ``_get_client``.
            audio_b64: Base64-encoded audio payload.

        Returns:
            The last chunk produced by ``client.inference``, or ``""`` if it
            produced nothing.
        """

        def _drain():
            # ``client.inference`` is a generator: it must be *consumed* in the
            # worker thread. Merely calling it there would defer all blocking
            # work to the event-loop thread that iterates it.
            last_chunk = ""
            for chunk in client.inference(audio_b64, "asr.base64"):
                last_chunk = chunk
            return last_chunk

        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self._inference_executor, _drain)

    async def create_transcription(self, audio_data: bytes, language: str = "zh", prompt: str = "") -> str:
        """Transcribe raw audio bytes and return the recognised text.

        Args:
            audio_data: Raw audio file content.
            language: Accepted for interface compatibility; not used here (the
                client is configured with ``"zh"`` during setup).
            prompt: Accepted for interface compatibility; not used here.

        Returns:
            The transcription text.

        Raises:
            RuntimeError: If no client is available or the inference fails.
            asyncio.CancelledError: If the task is cancelled (for example when
                another request pre-empts it); the running inference is
                stopped before the cancellation propagates.
        """
        client = await self._get_client()
        task = asyncio.current_task()
        try:
            # Registered so that ``_get_client`` can cancel this task when the
            # pool is exhausted.
            self._active_tasks.add(task)
            audio_b64 = base64.b64encode(audio_data).decode("utf-8")
            return await self._inference(client, audio_b64)
        except asyncio.CancelledError:
            self.logger.warning("Inference task cancelled, stopping...")
            client.stop_inference()
            raise
        except Exception as e:
            self.logger.error(f"Inference error: {str(e)}")
            raise RuntimeError(f"[ERROR: {e}]") from e
        finally:
            self._active_tasks.discard(task)
            await self._release_client(client)

backend/llm_client_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ async def _get_client(self, request):
9595

9696
loop = asyncio.get_event_loop()
9797
await loop.run_in_executor(
98-
None,
98+
None,
9999
lambda: client.setup(
100100
self.config["object"],
101101
{

client/asr_client.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Generator
66
import logging
77
import threading
8-
import base64
98

109
logger = logging.getLogger("asr_client")
1110
logger.setLevel(logging.DEBUG)
@@ -65,7 +64,7 @@ def setup(self, object: str, model_config: dict) -> dict:
6564
request_id = self._send_request("setup", object, model_config)
6665
return self._wait_response(request_id)
6766

68-
def inference_stream(self, query: str, object_type: str = "asr.base64") -> Generator[str, None, None]:
67+
def inference(self, query: str, object_type: str = "asr.base64") -> Generator[str, None, None]:
6968
request_id = self._send_request("inference", object_type, query)
7069

7170
while True:
@@ -100,21 +99,4 @@ def _wait_response(self, request_id: str) -> dict:
10099
def connect(self):
101100
with self._lock:
102101
if not self.sock:
103-
self._connect()
104-
105-
def create_transcription(self, audio_data: bytes, language: str = "zh") -> str:
106-
audio_b64 = base64.b64encode(audio_data).decode('utf-8')
107-
108-
self.setup("whisper.setup", {
109-
"model": "whisper-tiny",
110-
"response_format": "asr.utf-8",
111-
"input": "whisper.base64",
112-
"language": language,
113-
"enoutput": True,
114-
})
115-
116-
full_text = ""
117-
for chunk in self.inference_stream(audio_b64, object_type="asr.base64"):
118-
full_text += chunk
119-
120-
return full_text
102+
self._connect()

0 commit comments

Comments
 (0)