from asteroid_filterbanks import make_enc_dec
from ..masknn import DPTransformer
from .base_models import BaseEncoderMaskerDecoder


class DPTNet(BaseEncoderMaskerDecoder):
"""DPTNet separation model, as described in [1].

    Args:
        n_src (int): Number of masks to estimate.
        n_heads (int): Number of attention heads. Defaults to 4.
        ff_hid (int): Number of neurons in the RNN cell state of the improved
            transformer layer. Defaults to 256.
        chunk_size (int): Window size of overlap-and-add processing.
            Defaults to 100.
        hop_size (int or None): Hop size (stride) of overlap-and-add
            processing. Defaults to ``chunk_size // 2`` (50% overlap).
        n_repeats (int): Number of repeats. Defaults to 6.
        norm_type (str, optional): Type of normalization to use. To choose from
            - ``'gLN'``: global Layernorm
            - ``'cLN'``: channelwise Layernorm
        ff_activation (str, optional): Activation function applied at the
            output of the RNN. Defaults to ``'relu'``.
        encoder_activation (str, optional): Activation to apply after the
            encoder. Defaults to ``'relu'``.
        mask_act (str, optional): Which non-linear function to generate the
            mask. Defaults to ``'relu'``.
        bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN
            (Intra-Chunk is always bidirectional).
        dropout (float, optional): Dropout ratio, must be in [0, 1].
        in_chan (int, optional): Number of input channels; should be equal to
            ``n_filters``.
        fb_name (str): Filterbank family from which to make encoder and
            decoder. To choose among [``'free'``, ``'analytic_free'``,
            ``'param_sinc'``, ``'stft'``].
        kernel_size (int): Length of the filters.
        n_filters (int): Number of filters / input dimension of the masker net.
        stride (int, optional): Stride of the convolution. Defaults to 8
            (half the default kernel size).
        sample_rate (float): Sampling rate of the model. Defaults to 8000.
        **fb_kwargs (dict): Additional kwargs to pass to the filterbank
            creation.
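
    Example:
        A minimal usage sketch; shapes assume a one-second mono mixture at
        the default 8 kHz sample rate:

        >>> import torch
        >>> model = DPTNet(n_src=2)
        >>> mixture = torch.randn(1, 8000)  # (batch, time)
        >>> est_sources = model(mixture)  # (batch, n_src, time)
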
References
- [1]: Jingjing Chen et al. "Dual-Path Transformer Network: Direct
Context-Aware Modeling for End-to-End Monaural Speech Separation"
Interspeech 2020.
"""

    def __init__(
self,
n_src,
n_heads=4,
ff_hid=256,
chunk_size=100,
hop_size=None,
n_repeats=6,
norm_type="gLN",
ff_activation="relu",
encoder_activation="relu",
mask_act="relu",
bidirectional=True,
dropout=0,
in_chan=None,
fb_name="free",
kernel_size=16,
n_filters=64,
stride=8,
sample_rate=8000,
**fb_kwargs,
):
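        # Build the encoder/decoder pair from the requested filterbank family.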
encoder, decoder = make_enc_dec(
fb_name,
kernel_size=kernel_size,
n_filters=n_filters,
stride=stride,
sample_rate=sample_rate,
**fb_kwargs,
)
n_feats = encoder.n_feats_out
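        # If `in_chan` was given, check it against the filterbank's output
        # dimension, which is what the masker actually receives.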
if in_chan is not None:
assert in_chan == n_feats, (
"Number of filterbank output channels"
" and number of input channels should "
"be the same. Received "
f"{n_feats} and {in_chan}"
)
        # The masker operates on the filterbank's output dimension.
masker = DPTransformer(
n_feats,
n_src,
n_heads=n_heads,
ff_hid=ff_hid,
ff_activation=ff_activation,
chunk_size=chunk_size,
hop_size=hop_size,
n_repeats=n_repeats,
norm_type=norm_type,
mask_act=mask_act,
bidirectional=bidirectional,
dropout=dropout,
)
super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)