avspeech_dataset.py
import re
import librosa
import numpy as np
from pathlib import Path
import torch
from torch.utils import data
from torch.nn import functional as F
import pandas as pd
from typing import Union
from asteroid_filterbanks import Encoder, Decoder, STFTFB


def get_frames(video):
    """Read every frame of an opened ``cv2.VideoCapture`` into an RGB uint8 buffer."""
    import cv2  # Fix sphinx import

    frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    buffer_video = np.empty((frame_count, frame_height, frame_width, 3), np.dtype("uint8"))
    frame = 0
    ret = True
    while frame < frame_count and ret:
        ret, f = video.read()
        if not ret:
            # CAP_PROP_FRAME_COUNT can over-estimate; stop on the first failed read.
            break
        buffer_video[frame] = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
        frame += 1
    video.release()
    return buffer_video
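

# Illustrative sketch (not part of the original module): how ``get_frames`` might be
# called on a single clip. The path below is hypothetical; ``get_frames`` expects an
# already-opened ``cv2.VideoCapture``.
def _example_get_frames(video_path="data/clip.mp4"):
    import cv2

    capture = cv2.VideoCapture(video_path)
    frames = get_frames(capture)  # (n_frames, height, width, 3) uint8, RGB order
    return frames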


class Signal:
    """This class holds the video frames and the audio signal.

    Args:
        video_path (str, Path): Path to the video (mp4).
        audio_path (str, Path): Path to the audio (wav).
        embed_dir (str, Path): Path to the directory that stores the embeddings.
        sr (int): Sampling rate of the audio.
        video_start_length (int): Video part number.
        fps (int): Frame rate of the video.
        signal_len (int): Length of the signal, in seconds.

    .. note:: Each video consists of multiple parts, each of which contains
        ``fps * signal_len`` frames.
    """

    def __init__(
        self,
        video_path: Union[str, Path],
        audio_path: Union[str, Path],
        embed_dir: Union[str, Path],
        sr=16_000,
        video_start_length=0,
        fps=25,
        signal_len=3,
    ):
        if isinstance(video_path, str):
            video_path = Path(video_path)
        if isinstance(audio_path, str):
            audio_path = Path(audio_path)
        if isinstance(embed_dir, str):
            embed_dir = Path(embed_dir)

        self.video_path = video_path
        self.audio_path = audio_path
        self.video_start_length = video_start_length

        self.embed_path = None
        self.embed = None
        self.embed_dir = embed_dir

        self.fps = fps
        self.signal_len = signal_len
        self.sr = sr

        self._load(sr=sr)
        self._check_video_embed()

    def _load(self, sr: int):
        import cv2  # Fix sphinx import

        self.audio, _ = librosa.load(self.audio_path.as_posix(), sr=sr)
        self.video = cv2.VideoCapture(self.video_path.as_posix())

    def _check_video_embed(self, embed_ext=".npy"):
        # Map the mp4 location to the corresponding embedding file.
        video_name_stem = self.video_path.stem
        embed_dir = self.embed_dir
        if not embed_dir.is_dir():
            # Handle embed_dir given as "../../dir" instead of "dir".
            embed_dir = Path(*embed_dir.parts[2:])
        self.embed_path = Path(
            embed_dir, f"{video_name_stem}_part{self.video_start_length}{embed_ext}"
        )
        if self.embed_path.is_file():
            self.embed = np.load(self.embed_path.as_posix())
        else:
            raise ValueError(
                f"Embeddings not found in {self.embed_dir} for {self.video_path} "
                f"for part: {self.video_start_length}"
            )

    def get_embed(self):
        return self.embed

    def get_audio(self):
        return self.audio
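

# Illustrative sketch (not part of the original module): constructing a ``Signal``.
# All paths are hypothetical. As ``_check_video_embed`` shows, for a video stem "clip"
# and ``video_start_length=1`` a pre-computed embedding file "embeddings/clip_part1.npy"
# must already exist, otherwise a ValueError is raised.
def _example_signal():
    signal = Signal(
        video_path="videos/clip.mp4",
        audio_path="audio/clip_part1.wav",
        embed_dir="embeddings",
        video_start_length=1,
    )
    return signal.get_audio(), signal.get_embed()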


class AVSpeechDataset(data.Dataset):
    """Audio-visual speech separation dataset as described in [1].

    Args:
        input_df_path (str, Path): Path to the combination dataset csv.
        embed_dir (str, Path): Path to the directory where the embeddings are stored.
        n_src (int): Number of sources.

    References
        [1] "Looking to Listen at the Cocktail Party: A Speaker-Independent
        Audio-Visual Model for Speech Separation", Ephrat et al.
        https://arxiv.org/abs/1804.03619
    """

    dataset_name = "AVSpeech"

    def __init__(self, input_df_path: Union[str, Path], embed_dir: Union[str, Path], n_src=2):
        if isinstance(input_df_path, str):
            input_df_path = Path(input_df_path)
        if isinstance(embed_dir, str):
            embed_dir = Path(embed_dir)

        self.n_src = n_src
        self.embed_dir = embed_dir
        self.input_df = pd.read_csv(input_df_path.as_posix())
        self.stft_encoder = Encoder(STFTFB(n_filters=512, kernel_size=400, stride=160))

    @staticmethod
    def encode(x: np.ndarray, p=0.3, stft_encoder=None, EPS=1e-8):
        if stft_encoder is None:
            stft_encoder = Encoder(STFTFB(n_filters=512, kernel_size=400, stride=160))
        x = torch.from_numpy(x).float()
        # time domain to time-frequency representation
        tf_rep = stft_encoder(x).squeeze(0) + EPS
        # power law on complex numbers
        tf_rep = (torch.abs(tf_rep) ** p) * torch.sign(tf_rep)
        return tf_rep

    @staticmethod
    def decode(tf_rep: np.ndarray, p=0.3, stft_decoder=None, final_len=48000):
        if stft_decoder is None:
            stft_decoder = Decoder(STFTFB(n_filters=512, kernel_size=400, stride=160))
        tf_rep = torch.from_numpy(tf_rep).float()
        # inverse power law on complex numbers
        tf_rep = (torch.abs(tf_rep) ** (1 / p)) * torch.sign(tf_rep)
        # time-frequency representation back to time domain
        x = stft_decoder(tf_rep)
        length = len(x)
        if length != final_len:
            x = F.pad(x, [0, final_len - length])
        return x
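
    # Note (added for clarity, not in the original file): ``encode`` and ``decode`` are
    # approximate inverses. ``encode`` produces the power-law compressed STFT used in
    # the paper (|X| ** 0.3 with the sign preserved); ``decode`` undoes the compression
    # with the exponent 1 / p, runs the inverse STFT, and pads the waveform back to
    # ``final_len`` samples (48000, i.e. 3 s at 16 kHz). Roughly:
    #
    #     rep = AVSpeechDataset.encode(waveform)        # waveform: 1-D float numpy array
    #     reconstructed = AVSpeechDataset.decode(rep.numpy())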

    def __len__(self):
        return len(self.input_df)

    def __getitem__(self, idx):
        row = self.input_df.iloc[idx, :]
        all_signals = []
        for i in range(self.n_src):
            # get the audio and video paths from the combination dataframe
            video_path = row.loc[f"video_{i+1}"]
            audio_path = row.loc[f"audio_{i+1}"]
            # video length is 3-10 seconds, hence the part index can take values 0-2
            re_match = re.search(r"_part\d", audio_path)
            video_length_idx = 0
            if re_match:
                video_length_idx = int(re_match.group(0)[-1])
            signal = Signal(
                video_path,
                audio_path,
                self.embed_dir,
                video_start_length=video_length_idx,
            )
            all_signals.append(signal)

        # the mixed (input) audio signal is in the last column
        mixed_signal, _ = librosa.load(row.loc["mixed_audio"], sr=16_000)
        mixed_signal_tensor = self.encode(mixed_signal, stft_encoder=self.stft_encoder)

        audio_tensors = []
        video_tensors = []
        for i in range(self.n_src):
            # audio to spectrogram
            spectrogram = self.encode(all_signals[i].get_audio(), stft_encoder=self.stft_encoder)
            audio_tensors.append(spectrogram)
            # get embeddings
            embeddings = torch.from_numpy(all_signals[i].get_embed())
            video_tensors.append(embeddings)

        audio_tensors = torch.stack(audio_tensors)
        return audio_tensors, video_tensors, mixed_signal_tensor
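

# Minimal usage sketch (not part of the original module), assuming a combination csv
# with columns video_1, audio_1, video_2, audio_2, mixed_audio and a directory of
# pre-computed face embeddings. The paths below are hypothetical.
if __name__ == "__main__":
    dataset = AVSpeechDataset("metadata/combinations.csv", "embeddings", n_src=2)
    loader = data.DataLoader(dataset, batch_size=2, shuffle=True)
    audio_tensors, video_tensors, mixed = next(iter(loader))
    # audio_tensors: power-law compressed spectrograms of the n_src clean sources,
    # video_tensors: list of face-embedding tensors (one per source),
    # mixed: compressed spectrogram of the mixture fed to the separation model.
    print(audio_tensors.shape, mixed.shape)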