Skip to content

Commit 1a63918

Browse files
authored
Merge pull request jhj0517#366 from jhj0517/feature/enable-word-timestamps
Enable word timestamps
2 parents fb62be2 + f197459 commit 1a63918

File tree

10 files changed

+570
-276
lines changed

10 files changed

+570
-276
lines changed

app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def create_pipeline_inputs(self):
5353
dd_lang = gr.Dropdown(choices=self.whisper_inf.available_langs + [AUTOMATIC_DETECTION],
5454
value=AUTOMATIC_DETECTION if whisper_params["lang"] == AUTOMATIC_DETECTION.unwrap()
5555
else whisper_params["lang"], label=_("Language"))
56-
dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label=_("File Format"))
56+
dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value="SRT", label=_("File Format"))
5757
with gr.Row():
5858
cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label=_("Translate to English?"),
5959
interactive=True)

modules/diarize/diarize_pipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Optional, Union
88
import torch
99

10+
from modules.whisper.data_classes import *
1011
from modules.utils.paths import DIARIZATION_MODELS_DIR
1112
from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
1213

@@ -44,7 +45,8 @@ def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speaker
4445
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
4546
transcript_segments = transcript_result["segments"]
4647
for seg in transcript_segments:
47-
seg = seg.dict()
48+
if isinstance(seg, Segment):
49+
seg = seg.model_dump()
4850
# assign speaker to segment (if any)
4951
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
5052
seg['start'])
@@ -64,7 +66,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
6466
seg["speaker"] = speaker
6567

6668
# assign speaker to words
67-
if 'words' in seg:
69+
if 'words' in seg and seg['words'] is not None:
6870
for word in seg['words']:
6971
if 'start' in word:
7072
diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
@@ -89,7 +91,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
8991
return transcript_result
9092

9193

92-
class Segment:
94+
class DiarizationSegment:
9395
def __init__(self, start, end, speaker=None):
9496
self.start = start
9597
self.end = end

modules/translation/deepl_api.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -139,37 +139,27 @@ def translate_deepl(self,
139139
)
140140

141141
files_info = {}
142-
for fileobj in fileobjs:
143-
file_path = fileobj
144-
file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
145-
146-
if file_ext == ".srt":
147-
parsed_dicts = parse_srt(file_path=file_path)
148-
149-
elif file_ext == ".vtt":
150-
parsed_dicts = parse_vtt(file_path=file_path)
142+
for file_path in fileobjs:
143+
file_name, file_ext = os.path.splitext(os.path.basename(file_path))
144+
writer = get_writer(file_ext, self.output_dir)
145+
segments = writer.to_segments(file_path)
151146

152147
batch_size = self.max_text_batch_size
153-
for batch_start in range(0, len(parsed_dicts), batch_size):
154-
batch_end = min(batch_start + batch_size, len(parsed_dicts))
155-
sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
148+
for batch_start in range(0, len(segments), batch_size):
149+
progress(batch_start / len(segments), desc="Translating..")
150+
sentences_to_translate = [seg.text for seg in segments[batch_start:batch_start+batch_size]]
156151
translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
157152
target_lang, is_pro)
158153
for i, translated_text in enumerate(translated_texts):
159-
parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
160-
progress(batch_end / len(parsed_dicts), desc="Translating..")
161-
162-
if file_ext == ".srt":
163-
subtitle = get_serialized_srt(parsed_dicts)
164-
elif file_ext == ".vtt":
165-
subtitle = get_serialized_vtt(parsed_dicts)
166-
167-
if add_timestamp:
168-
timestamp = datetime.now().strftime("%m%d%H%M%S")
169-
file_name += f"-{timestamp}"
170-
171-
output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
172-
write_file(subtitle, output_path)
154+
segments[batch_start + i].text = translated_text["text"]
155+
156+
subtitle, output_path = generate_file(
157+
output_dir=self.output_dir,
158+
output_file_name=file_name,
159+
output_format=file_ext,
160+
result=segments,
161+
add_timestamp=add_timestamp
162+
)
173163

174164
files_info[file_name] = {"subtitle": subtitle, "path": output_path}
175165

modules/translation/translation_base.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -95,32 +95,22 @@ def translate_file(self,
9595
files_info = {}
9696
for fileobj in fileobjs:
9797
file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
98-
if file_ext == ".srt":
99-
parsed_dicts = parse_srt(file_path=fileobj)
100-
total_progress = len(parsed_dicts)
101-
for index, dic in enumerate(parsed_dicts):
102-
progress(index / total_progress, desc="Translating..")
103-
translated_text = self.translate(dic["sentence"], max_length=max_length)
104-
dic["sentence"] = translated_text
105-
subtitle = get_serialized_srt(parsed_dicts)
106-
107-
elif file_ext == ".vtt":
108-
parsed_dicts = parse_vtt(file_path=fileobj)
109-
total_progress = len(parsed_dicts)
110-
for index, dic in enumerate(parsed_dicts):
111-
progress(index / total_progress, desc="Translating..")
112-
translated_text = self.translate(dic["sentence"], max_length=max_length)
113-
dic["sentence"] = translated_text
114-
subtitle = get_serialized_vtt(parsed_dicts)
115-
116-
if add_timestamp:
117-
timestamp = datetime.now().strftime("%m%d%H%M%S")
118-
file_name += f"-{timestamp}"
119-
120-
output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
121-
write_file(subtitle, output_path)
122-
123-
files_info[file_name] = {"subtitle": subtitle, "path": output_path}
98+
writer = get_writer(file_ext, self.output_dir)
99+
segments = writer.to_segments(fileobj)
100+
for i, segment in enumerate(segments):
101+
progress(i / len(segments), desc="Translating..")
102+
translated_text = self.translate(segment.text, max_length=max_length)
103+
segment.text = translated_text
104+
105+
subtitle, file_path = generate_file(
106+
output_dir=self.output_dir,
107+
output_file_name=file_name,
108+
output_format=file_ext,
109+
result=segments,
110+
add_timestamp=add_timestamp
111+
)
112+
113+
files_info[file_name] = {"subtitle": subtitle, "path": file_path}
124114

125115
total_result = ''
126116
for file_name, info in files_info.items():
@@ -133,7 +123,8 @@ def translate_file(self,
133123
return [gr_str, output_file_paths]
134124

135125
except Exception as e:
136-
print(f"Error: {str(e)}")
126+
print(f"Error translating file: {e}")
127+
raise
137128
finally:
138129
self.release_cuda_memory()
139130

modules/utils/files_manager.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,9 @@ def is_video(file_path):
6767
video_extensions = ['.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv', '.webm', '.m4v', '.mpeg', '.mpg', '.3gp']
6868
extension = os.path.splitext(file_path)[1].lower()
6969
return extension in video_extensions
70+
71+
72+
def read_file(file_path):
73+
with open(file_path, "r", encoding="utf-8") as f:
74+
subtitle_content = f.read()
75+
return subtitle_content

0 commit comments

Comments
 (0)