-from .base_model_backend import BaseModelBackend
-from client.asr_client import ASRClient
+import time
 import asyncio
+import weakref
 import base64
 import logging
+from .base_model_backend import BaseModelBackend
+from client.asr_client import ASRClient
 from concurrent.futures import ThreadPoolExecutor
-
-logger = logging.getLogger("api.asr")
+from services.memory_check import MemoryChecker

 class ASRClientBackend(BaseModelBackend):
-    POOL_SIZE = 1
-    SUPPORTED_FORMATS = ["json", "text", "srt", "verbose_json"]
-
     def __init__(self, model_config):
         super().__init__(model_config)
-        self._executor = ThreadPoolExecutor(max_workers=self.POOL_SIZE)
-        self.clients = []
-        self._lock = asyncio.Lock()
+        self._client_pool = []
+        self._active_clients = {}
+        self._pool_lock = asyncio.Lock()
+        self.logger = logging.getLogger("api.asr")
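+        # POOL_SIZE sizes the single-worker executor used for blocking inference calls.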
+        self.POOL_SIZE = 1
+        self._inference_executor = ThreadPoolExecutor(max_workers=self.POOL_SIZE)
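+        # Track in-flight transcription tasks so _get_client can cancel them when no client is free.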
+        self._active_tasks = weakref.WeakSet()
+        self.memory_checker = MemoryChecker(
+            host=self.config["host"],
+            port=self.config["port"]
+        )

-    async def create_transcription(self, audio_data: bytes, language: str = "zh", prompt: str = "") -> str:
-        client = await self._get_client()
+    async def _get_client(self):
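+        # Acquire a client: take one from the pool if available; otherwise cancel
+        # in-flight tasks, wait for the pool lock again, then create and set up a
+        # fresh ASRClient. Gives up with a RuntimeError after roughly 30 seconds.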
         try:
-            audio_b64 = base64.b64encode(audio_data).decode('utf-8')
-            return await self._inference_stream(client, audio_b64)
+            await asyncio.wait_for(self._pool_lock.acquire(), timeout=30.0)
+
+            start_time = time.time()
+            timeout = 30.0
+            retry_interval = 3
+
+            while True:
+                if self._client_pool:
+                    client = self._client_pool.pop()
+                    return client
+
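+                # Pool is empty: cancel in-flight transcriptions so their finally
+                # blocks release their clients back to the pool.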
+                for task in self._active_tasks:
+                    task.cancel()
+
+                self._pool_lock.release()
+                await asyncio.sleep(retry_interval)
+                await asyncio.wait_for(self._pool_lock.acquire(), timeout=timeout - (time.time() - start_time))
+
+                # if "memory_required" in self.config:
+                #     await self.memory_checker.check_memory(self.config["memory_required"])
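+                # Create a fresh ASRClient and run its whisper.setup in the default executor.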
+                client = ASRClient(
+                    host=self.config["host"],
+                    port=self.config["port"]
+                )
+                self._active_clients[id(client)] = client
+
+                loop = asyncio.get_event_loop()
+                await loop.run_in_executor(
+                    None,
+                    client.setup,
+                    "whisper.setup",
+                    {
+                        "model": self.config["model_name"],
+                        "response_format": "asr.utf-8",
+                        "input": "whisper.base64",
+                        "language": "zh",
+                        "enoutput": True
+                    }
+                )
+                return client
+        except asyncio.TimeoutError:
+            raise RuntimeError("Server busy, please try again later.")
         finally:
-            await self._release_client(client)
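+            # Acquisition of the pool lock may have timed out, so release it only if it is still held.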
+            if self._pool_lock.locked():
+                self._pool_lock.release()

-    async def _inference_stream(self, client, audio_b64: str) -> str:
+    async def _release_client(self, client):
+        async with self._pool_lock:
+            self._client_pool.append(client)
+
+    async def _inference(self, client, audio_b64: str):
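+        # Run the blocking client.inference call on the dedicated executor,
+        # iterate the chunks it returns, and keep the last one as the result.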
         loop = asyncio.get_event_loop()
-        full_text = ""
         for chunk in await loop.run_in_executor(
-            self._executor,
-            client.inference_stream,
+            self._inference_executor,
+            client.inference,
             audio_b64,
             "asr.base64"
         ):
-            full_text += chunk
-        return full_text
-
-    async def _get_client(self):
-        async with self._lock:
-            if self.clients:
-                return self.clients.pop()
-
-            if len(self.clients) >= self.POOL_SIZE:
-                raise RuntimeError("ASR connection pool exhausted")
-
-            client = ASRClient(
-                host=self.config["host"],
-                port=self.config["port"]
-            )
-
-            await asyncio.get_event_loop().run_in_executor(
-                self._executor,
-                client.setup,
-                "whisper.setup",
-                {
-                    "model": self.config["model_name"],
-                    "response_format": "asr.utf-8",
-                    "input": "whisper.base64",
-                    "language": "zh",
-                    "enoutput": True
-                }
-            )
-            return client
+            full_result = chunk
+        return full_result

-    async def _release_client(self, client):
-        async with self._lock:
-            self.clients.append(client)
+    async def create_transcription(self, audio_data: bytes, language: str = "zh", prompt: str = "") -> str:
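+        # Register this task so _get_client can cancel it to reclaim the client;
+        # the client is always returned to the pool in the finally block.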
+        client = await self._get_client()
+        task = asyncio.current_task()
+        self._active_tasks.add(task)
+        try:
+            audio_b64 = base64.b64encode(audio_data).decode('utf-8')
+            return await self._inference(client, audio_b64)
+        except asyncio.CancelledError:
+            self.logger.warning("Inference task cancelled, stopping...")
+            client.stop_inference()
+            raise
+        except Exception as e:
+            self.logger.error(f"Inference error: {str(e)}")
+            raise RuntimeError(f"[ERROR] {str(e)}")
+        finally:
+            self._active_tasks.discard(task)
+            await self._release_client(client)