diff --git a/api_server.py b/api_server.py index 64d49b9..95f96d8 100644 --- a/api_server.py +++ b/api_server.py @@ -62,6 +62,7 @@ def __init__(self): self.backends = {} self.llm_models = set() self.asr_models = set() + self.tts_models = set() self.lock = asyncio.Lock() async def get_backend(self, model_name): @@ -84,7 +85,14 @@ async def get_backend(self, model_name): elif model_config["type"] == "vision_model": self.backends[model_name] = VisionModelBackend(model_config) elif model_config["type"] == "tts": + if model_name not in self.tts_models: + for old_model_name in list(self.tts_models): + old_instance = self.backends.pop(old_model_name, None) + if old_instance: + await old_instance.close() + self.asr_models.clear() self.backends[model_name] = TtsClientBackend(model_config) + self.tts_models.add(model_name) elif model_config["type"] == "asr": if model_name not in self.asr_models: for old_model_name in list(self.asr_models): diff --git a/backend/tts_client_backend.py b/backend/tts_client_backend.py index fa4883b..8039ca6 100644 --- a/backend/tts_client_backend.py +++ b/backend/tts_client_backend.py @@ -82,6 +82,17 @@ async def _release_client(self, client): async with self._pool_lock: self._client_pool.append(client) + async def close(self): + for task in self._active_tasks: + task.cancel() + if self._active_tasks: + await asyncio.wait(self._active_tasks, timeout=2) + for client in self._client_pool: + client.exit() + self._client_pool.clear() + self._active_clients.clear() + self._inference_executor.shutdown(wait=False) + def _encode_stream_chunk(self, pcm_data: bytes, format: str) -> bytes: if format == "pcm": return pcm_data