Skip to content

Commit c0bcfd4

Browse files
author
LittleMouse
committed
[perf] Add audio segmentation and transcoding
1 parent 0bf4767 commit c0bcfd4

File tree

2 files changed

+60
-29
lines changed

2 files changed

+60
-29
lines changed

api_server.py

Lines changed: 35 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import io
12
import os
23
import uuid
34
import yaml
@@ -6,6 +7,7 @@
67
import json
78
import asyncio
89

10+
from pydub import AudioSegment
911
from fastapi import FastAPI, Request, HTTPException, File, Form, UploadFile
1012
from fastapi.responses import JSONResponse, StreamingResponse
1113
from backend import (
@@ -281,16 +283,42 @@ async def create_transcription(
281283

282284
try:
283285
audio_data = await file.read()
284-
transcription = await backend.create_transcription(
285-
audio_data,
286-
language=language,
287-
prompt=prompt
288-
)
286+
audio = AudioSegment.from_file(io.BytesIO(audio_data), format=file.filename.split('.')[-1])
287+
288+
target_sample_rate = 16000
289+
target_channels = 1
290+
target_sample_width = 2
291+
292+
if audio.frame_rate != target_sample_rate or audio.channels != target_channels or audio.sample_width != target_sample_width:
293+
audio = audio.set_frame_rate(target_sample_rate).set_channels(target_channels).set_sample_width(target_sample_width)
294+
295+
segment_duration_ms = 30 * 1000
296+
segments = [audio[i:i + segment_duration_ms] for i in range(0, len(audio), segment_duration_ms)]
297+
298+
transcription_results = []
299+
for segment in segments:
300+
segment_data = io.BytesIO()
301+
segment.export(segment_data, format="wav")
302+
segment_data.seek(0)
303+
304+
transcription = await backend.create_transcription(
305+
segment_data.read(),
306+
language=language,
307+
prompt=prompt
308+
)
309+
transcription_results.append(transcription)
310+
311+
full_transcription = " ".join(transcription_results)
312+
289313
return JSONResponse(content={
290-
"text": transcription,
314+
"text": full_transcription,
291315
"task": "transcribe",
292316
"language": language,
293-
"duration": 0
317+
"duration": len(audio) / 1000.0,
318+
"segments": len(segments),
319+
"sample_rate": target_sample_rate,
320+
"channels": target_channels,
321+
"bit_depth": target_sample_width * 8
294322
})
295323
except Exception as e:
296324
logger.error(f"Transcription error: {str(e)}")

backend/asr_client_backend.py

Lines changed: 25 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -37,36 +37,39 @@ async def _get_client(self):
3737
self.logger.debug(f"Reusing client from pool | ID:{id(client)}")
3838
return client
3939

40+
if len(self._active_clients) < self.POOL_SIZE:
41+
break
42+
4043
for task in self._active_tasks:
4144
task.cancel()
4245

4346
self._pool_lock.release()
4447
await asyncio.sleep(retry_interval)
4548
await asyncio.wait_for(self._pool_lock.acquire(), timeout=timeout - (time.time() - start_time))
4649

47-
# if "memory_required" in self.config:
48-
# await self.memory_checker.check_memory(self.config["memory_required"])
49-
self.logger.debug("Creating new LLM client")
50-
client = ASRClient(
51-
host=self.config["host"],
52-
port=self.config["port"]
53-
)
54-
self._active_clients[id(client)] = client
50+
if "memory_required" in self.config:
51+
await self.memory_checker.check_memory(self.config["memory_required"])
52+
self.logger.debug("Creating new LLM client")
53+
client = ASRClient(
54+
host=self.config["host"],
55+
port=self.config["port"]
56+
)
57+
self._active_clients[id(client)] = client
5558

56-
loop = asyncio.get_event_loop()
57-
await loop.run_in_executor(
58-
None,
59-
client.setup,
60-
"whisper.setup",
61-
{
62-
"model": self.config["model_name"],
63-
"response_format": "asr.utf-8",
64-
"input": "whisper.base64.stream",
65-
"language": "zh",
66-
"enoutput": True
67-
}
68-
)
69-
return client
59+
loop = asyncio.get_event_loop()
60+
await loop.run_in_executor(
61+
None,
62+
client.setup,
63+
"whisper.setup",
64+
{
65+
"model": self.config["model_name"],
66+
"response_format": "asr.utf-8",
67+
"input": "whisper.base64.stream",
68+
"language": "zh",
69+
"enoutput": True
70+
}
71+
)
72+
return client
7073
except asyncio.TimeoutError:
7174
raise RuntimeError("Server busy, please try again later.")
7275
finally:

0 commit comments

Comments
 (0)