# Copyright 2020 LMNT, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import os
import random
import torch
import torchaudio
from glob import glob
from torch.utils.data.distributed import DistributedSampler


class NumpyDataset(torch.utils.data.Dataset):
  """Pairs precomputed spectrograms (*.wav.spec.npy) with their source waveforms."""

  def __init__(self, wav_path, npy_paths, se, voicebank=False):
    super().__init__()
    self.wav_path = wav_path
    self.specnames = []
    self.se = se
    self.voicebank = voicebank
    print(npy_paths, wav_path)
    for path in npy_paths:
      # Note: recursive=True only takes effect when the pattern contains '**'.
      self.specnames += glob(f'{path}/*.wav.spec.npy', recursive=True)

  def __len__(self):
    return len(self.specnames)
  def __getitem__(self, idx):
    spec_filename = self.specnames[idx]
    if self.voicebank:
      spec_path = "/".join(spec_filename.split("/")[:-1])
      audio_filename = spec_filename.replace(spec_path, self.wav_path).replace(".spec.npy", "")
    else:
      spec_path = "/".join(spec_filename.split("/")[:-2]) + "/"
      if self.se:
        # Speech enhancement: pair the noisy spectrogram with its clean waveform.
        audio_filename = spec_filename.replace(spec_path, self.wav_path).replace(".wav.spec.npy", ".Clean.wav")
      else:
        audio_filename = spec_filename.replace(spec_path, self.wav_path).replace(".spec.npy", "")
    # torchaudio.load_wav was removed in newer torchaudio releases;
    # load(..., normalize=False) keeps the raw int16 samples, matching it.
    signal, _ = torchaudio.load(audio_filename, normalize=False)
    spectrogram = np.load(spec_filename)
    return {
        'audio': signal[0] / 32767.5,  # scale int16 samples to [-1, 1]
        'spectrogram': spectrogram.T
    }
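
# Illustrative directory layout assumed by __getitem__ in the `se` case
# (hypothetical paths, shown only to document the path mapping above):
#   spectrogram: <root>/<subdir>/p232_001.wav.spec.npy
#   clean audio: <wav_path>/<subdir>/p232_001.Clean.wav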


class Collator:
  def __init__(self, params):
    self.params = params

  def collate(self, minibatch):
    samples_per_frame = self.params.hop_samples
    for record in minibatch:
      # Filter out records that aren't long enough.
      if len(record['spectrogram']) < self.params.crop_mel_frames:
        del record['spectrogram']
        del record['audio']
        continue

      # Crop a random window of crop_mel_frames spectrogram frames...
      start = random.randint(0, record['spectrogram'].shape[0] - self.params.crop_mel_frames)
      end = start + self.params.crop_mel_frames
      record['spectrogram'] = record['spectrogram'][start:end].T

      # ...and the aligned audio span, padding short tails with zeros.
      start *= samples_per_frame
      end *= samples_per_frame
      record['audio'] = record['audio'][start:end]
      record['audio'] = np.pad(record['audio'], (0, (end - start) - len(record['audio'])), mode='constant')

    audio = np.stack([record['audio'] for record in minibatch if 'audio' in record])
    spectrogram = np.stack([record['spectrogram'] for record in minibatch if 'spectrogram' in record])
    return {
        'audio': torch.from_numpy(audio),
        'spectrogram': torch.from_numpy(spectrogram),
    }
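
# For example, with hop_samples=256 and crop_mel_frames=62 (assumed values,
# not taken from this file), each collated example pairs 62 spectrogram
# frames with 62 * 256 = 15,872 audio samples.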


def from_path(clean_dir, data_dirs, params, se=True, voicebank=False, is_distributed=False):
  dataset = NumpyDataset(clean_dir, data_dirs, se, voicebank)
  return torch.utils.data.DataLoader(
      dataset,
      batch_size=params.batch_size,
      collate_fn=Collator(params).collate,
      shuffle=not is_distributed,
      num_workers=os.cpu_count(),
      sampler=DistributedSampler(dataset) if is_distributed else None,
      pin_memory=True,
      drop_last=True)
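

# Minimal usage sketch (not part of the original file). `params` is assumed to
# be an attribute-style object exposing batch_size, hop_samples and
# crop_mel_frames; the directory paths are hypothetical placeholders.
if __name__ == '__main__':
  class _Params:
    batch_size = 16
    hop_samples = 256
    crop_mel_frames = 62

  loader = from_path('data/clean/', ['data/specs/noisy'], _Params(), se=True)
  batch = next(iter(loader))
  print(batch['audio'].shape, batch['spectrogram'].shape)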