Skip to content

Commit

Permalink
add 3 augmentor and unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
xushaoyong committed Jun 27, 2017
2 parents 75ea374 + 01c4df2 commit 9ec357e
Show file tree
Hide file tree
Showing 40 changed files with 101,317 additions and 1,045 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
group: deprecated-2017Q2
language: cpp
cache: ccache
sudo: required
Expand Down
18 changes: 12 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,15 @@ PaddlePaddle提供了丰富的运算单元,帮助大家以模块化的方式

在词向量的例子中,我们向大家展示如何使用Hierarchical-Sigmoid 和噪声对比估计(Noise Contrastive Estimation,NCE)来加速词向量的学习。

- 1.1 [Hsigmoid加速词向量训练](https://github.com/PaddlePaddle/models/tree/develop/word_embedding)
- 1.1 [Hsigmoid加速词向量训练](https://github.com/PaddlePaddle/models/tree/develop/hsigmoid)
- 1.2 [噪声对比估计加速词向量训练](https://github.com/PaddlePaddle/models/tree/develop/nce_cost)


## 2. 语言模型
## 2. 使用循环神经网络语言模型生成文本

语言模型是自然语言处理领域里一个重要的基础模型,它是一个概率分布模型,利用它可以确定哪个词序列的可能性更大,或者给定若干个词,可以预测下一个最可能出现的词。语言模型被应用在很多领域,如:自动写作、QA、机器翻译、拼写检查、语音识别、词性标注等
语言模型是自然语言处理领域里一个重要的基础模型,除了得到词向量(语言模型训练的副产物),还可以帮助我们生成文本。给定若干个词,语言模型可以帮助我们预测下一个最可能出现的词。在利用语言模型生成文本的例子中,我们重点介绍循环神经网络语言模型,大家可以通过文档中的使用说明快速适配到自己的训练语料,完成自动写诗、自动写散文等有趣的模型

在语言模型的例子中,我们以文本生成为例,提供了RNN LM(包括LSTM、GRU)和N-Gram LM,供大家学习和使用。用户可以通过文档中的 “使用说明” 快速上手:适配训练语料,以训练 “自动写诗”、“自动写散文” 等有趣的模型。

- 2.1 [基于LSTM、GRU、N-Gram的文本生成模型](https://github.com/PaddlePaddle/models/tree/develop/language_model)
- 2.1 [使用循环神经网络语言模型生成文本](https://github.com/PaddlePaddle/models/tree/develop/generate_sequence_by_rnn_lm)

## 3. 点击率预估

Expand Down Expand Up @@ -65,6 +63,14 @@ PaddlePaddle提供了丰富的运算单元,帮助大家以模块化的方式

- 7.1 [无注意力机制的编码器解码器模型](https://github.com/PaddlePaddle/models/tree/develop/nmt_without_attention)

## 8. 图像分类
图像相比文字能够提供更加生动、容易理解及更具艺术感的信息,是人们转递与交换信息的重要来源。在图像分类的例子中,我们向大家介绍如何在PaddlePaddle中训练AlexNet、VGG、GoogLeNet和ResNet模型。同时还提供了一个模型转换工具,能够将Caffe训练好的模型文件,转换为PaddlePaddle的模型文件。

- 8.1 [将Caffe模型文件转换为PaddlePaddle模型文件](https://github.com/PaddlePaddle/models/tree/develop/image_classification/caffe2paddle)
- 8.2 [AlexNet](https://github.com/PaddlePaddle/models/tree/develop/image_classification)
- 8.3 [VGG](https://github.com/PaddlePaddle/models/tree/develop/image_classification)
- 8.4 [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/image_classification)


## Copyright and License
PaddlePaddle is provided under the [Apache-2.0 license](LICENSE).
4 changes: 2 additions & 2 deletions deep_speech_2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ python compute_mean_std.py --help
For GPU Training:

```
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py
```

For CPU Training:

```
python train.py --trainer_count 8 --use_gpu False
python train.py --use_gpu False
```

More help for arguments:
Expand Down
160 changes: 93 additions & 67 deletions deep_speech_2/data_utils/audio.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,54 @@ def from_file(cls, file):
samples, sample_rate = soundfile.read(file, dtype='float32')
return cls(samples, sample_rate)

@classmethod
def slice_from_file(cls, file, start=None, end=None):
"""Loads a small section of an audio without having to load
the entire file into the memory which can be incredibly wasteful.
:param file: Input audio filepath or file object.
:type file: basestring|file
:param start: Start time in seconds. If start is negative, it wraps
around from the end. If not provided, this function
reads from the very beginning.
:type start: float
:param end: End time in seconds. If end is negative, it wraps around
from the end. If not provided, the default behvaior is
to read to the end of the file.
:type end: float
:return: AudioSegment instance of the specified slice of the input
audio file.
:rtype: AudioSegment
:raise ValueError: If start or end is incorrectly set, e.g. out of
bounds in time.
"""
sndfile = soundfile.SoundFile(file)
sample_rate = sndfile.samplerate
duration = float(len(sndfile)) / sample_rate
start = 0. if start is None else start
end = 0. if end is None else end
if start < 0.0:
start += duration
if end < 0.0:
end += duration
if start < 0.0:
raise ValueError("The slice start position (%f s) is out of "
"bounds." % start)
if end < 0.0:
raise ValueError("The slice end position (%f s) is out of bounds." %
end)
if start > end:
raise ValueError("The slice start position (%f s) is later than "
"the slice end position (%f s)." % (start, end))
if end > duration:
raise ValueError("The slice end position (%f s) is out of bounds "
"(> %f s)" % (end, duration))
start_frame = int(start * sample_rate)
end_frame = int(end * sample_rate)
sndfile.seek(start_frame)
data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
return cls(data, sample_rate)

@classmethod
def from_bytes(cls, bytes):
"""Create audio segment from a byte string containing audio samples.
Expand Down Expand Up @@ -105,6 +153,20 @@ def concatenate(cls, *segments):
samples = np.concatenate([seg.samples for seg in segments])
return cls(samples, sample_rate)

@classmethod
def make_silence(cls, duration, sample_rate):
"""Creates a silent audio segment of the given duration and sample rate.
:param duration: Length of silence in seconds.
:type duration: float
:param sample_rate: Sample rate.
:type sample_rate: float
:return: Silent AudioSegment instance of the given duration.
:rtype: AudioSegment
"""
samples = np.zeros(int(duration * sample_rate))
return cls(samples, sample_rate)

def to_wav_file(self, filepath, dtype='float32'):
"""Save audio segment to disk as wav file.
Expand All @@ -130,68 +192,6 @@ def to_wav_file(self, filepath, dtype='float32'):
format='WAV',
subtype=subtype_map[dtype])

@classmethod
def slice_from_file(cls, file, start=None, end=None):
"""Loads a small section of an audio without having to load
the entire file into the memory which can be incredibly wasteful.
:param file: Input audio filepath or file object.
:type file: basestring|file
:param start: Start time in seconds. If start is negative, it wraps
around from the end. If not provided, this function
reads from the very beginning.
:type start: float
:param end: End time in seconds. If end is negative, it wraps around
from the end. If not provided, the default behvaior is
to read to the end of the file.
:type end: float
:return: AudioSegment instance of the specified slice of the input
audio file.
:rtype: AudioSegment
:raise ValueError: If start or end is incorrectly set, e.g. out of
bounds in time.
"""
sndfile = soundfile.SoundFile(file)
sample_rate = sndfile.samplerate
duration = float(len(sndfile)) / sample_rate
start = 0. if start is None else start
end = 0. if end is None else end
if start < 0.0:
start += duration
if end < 0.0:
end += duration
if start < 0.0:
raise ValueError("The slice start position (%f s) is out of "
"bounds." % start)
if end < 0.0:
raise ValueError("The slice end position (%f s) is out of bounds." %
end)
if start > end:
raise ValueError("The slice start position (%f s) is later than "
"the slice end position (%f s)." % (start, end))
if end > duration:
raise ValueError("The slice end position (%f s) is out of bounds "
"(> %f s)" % (end, duration))
start_frame = int(start * sample_rate)
end_frame = int(end * sample_rate)
sndfile.seek(start_frame)
data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
return cls(data, sample_rate)

@classmethod
def make_silence(cls, duration, sample_rate):
"""Creates a silent audio segment of the given duration and sample rate.
:param duration: Length of silence in seconds.
:type duration: float
:param sample_rate: Sample rate.
:type sample_rate: float
:return: Silent AudioSegment instance of the given duration.
:rtype: AudioSegment
"""
samples = np.zeros(int(duration * sample_rate))
return cls(samples, sample_rate)

def superimpose(self, other):
"""Add samples from another segment to those of this segment
(sample-wise addition, not segment concatenation).
Expand Down Expand Up @@ -225,7 +225,7 @@ def to_bytes(self, dtype='float32'):
samples = self._convert_samples_from_float32(self._samples, dtype)
return samples.tostring()

def apply_gain(self, gain):
def gain_db(self, gain):
"""Apply gain in decibels to samples.
Note that this is an in-place transformation.
Expand Down Expand Up @@ -278,7 +278,7 @@ def normalize(self, target_db=-20, max_gain_db=300.0):
"Unable to normalize segment to %f dB because the "
"the probable gain have exceeds max_gain_db (%f dB)" %
(target_db, max_gain_db))
self.apply_gain(min(max_gain_db, target_db - self.rms_db))
self.gain_db(min(max_gain_db, target_db - self.rms_db))

def normalize_online_bayesian(self,
target_db,
Expand Down Expand Up @@ -319,7 +319,7 @@ def normalize_online_bayesian(self,
rms_estimate_db = 10 * np.log10(mean_squared_estimate)
# Compute required time-varying gain.
gain_db = target_db - rms_estimate_db
self.apply_gain(gain_db)
self.gain_db(gain_db)

def resample(self, target_sample_rate, filter='kaiser_best'):
"""Resample the audio to a target sample rate.
Expand All @@ -329,9 +329,10 @@ def resample(self, target_sample_rate, filter='kaiser_best'):
:param target_sample_rate: Target sample rate.
:type target_sample_rate: int
:param filter: The resampling filter to use one of {'kaiser_best',
'kaiser_fast'}.
'kaiser_fast'}.
:type filter: str
"""
resample_ratio = target_sample_rate / self._sample_rate
self._samples = resampy.resample(
self.samples, self.sample_rate, target_sample_rate, filter=filter)
self._sample_rate = target_sample_rate
Expand Down Expand Up @@ -364,6 +365,31 @@ def pad_silence(self, duration, sides='both'):
raise ValueError("Unknown value for the sides %s" % sides)
self._samples = padded._samples

def shift(self, shift_ms):
"""Shift the audio in time. If `shift_ms` is positive, shift with time
advance; if negative, shift with time delay. Silence are padded to
keep the duration unchanged.
Note that this is an in-place transformation.
:param shift_ms: Shift time in millseconds. If positive, shift with
time advance; if negative; shift with time delay.
:type shift_ms: float
:raises ValueError: If shift_ms is longer than audio duration.
"""
if abs(shift_ms) / 1000.0 > self.duration:
raise ValueError("Absolute value of shift_ms should be smaller "
"than audio duration.")
shift_samples = int(shift_ms * self._sample_rate / 1000)
if shift_samples > 0:
# time advance
self._samples[:-shift_samples] = self._samples[shift_samples:]
self._samples[-shift_samples:] = 0
elif shift_samples < 0:
# time delay
self._samples[-shift_samples:] = self._samples[:shift_samples]
self._samples[:-shift_samples] = 0

def subsegment(self, start_sec=None, end_sec=None):
"""Cut the AudioSegment between given boundaries.
Expand Down Expand Up @@ -503,7 +529,7 @@ def add_noise(self,
noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
noise_new = copy.deepcopy(noise)
noise_new.random_subsegment(self.duration, rng=rng)
noise_new.apply_gain(noise_gain_db)
noise_new.gain_db(noise_gain_db)
self.superimpose(noise_new)

@property
Expand Down
9 changes: 6 additions & 3 deletions deep_speech_2/data_utils/augmentor/augmentation.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import random
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
from data_utils.augmentor.resample import ResampleAugmentor
from data_utils.augmentor.online_bayesian_normalization import OnlineBayesianNormalizationAugmentor
Expand Down Expand Up @@ -79,11 +80,13 @@ def _get_augmentor(self, augmentor_type, params):
"""Return an augmentation model by the type name, and pass in params."""
if augmentor_type == "volume":
return VolumePerturbAugmentor(self._rng, **params)
if augmentor_type == "speed":
elif augmentor_type == "shift":
return ShiftPerturbAugmentor(self._rng, **params)
elif augmentor_type == "speed":
return SpeedPerturbAugmentor(self._rng, **params)
if augmentor_type == "resample":
elif augmentor_type == "resample":
return ResampleAugmentor(self._rng, **params)
if augmentor_type == "bayesian_normal":
elif augmentor_type == "bayesian_normal":
return OnlineBayesianNormalizationAugmentor(self._rng, **params)
else:
raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
2 changes: 1 addition & 1 deletion deep_speech_2/data_utils/augmentor/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ def transform_audio(self, audio_segment):
:param audio: Audio segment to add effects to.
:type audio: AudioSegment|SpeechSegment
"""
audio_segment.resample(self._new_sample_rate)
audio_segment.resample(self._new_sample_rate)
34 changes: 34 additions & 0 deletions deep_speech_2/data_utils/augmentor/shift_perturb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Contains the volume perturb augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from data_utils.augmentor.base import AugmentorBase


class ShiftPerturbAugmentor(AugmentorBase):
"""Augmentation model for adding random shift perturbation.
:param rng: Random generator object.
:type rng: random.Random
:param min_shift_ms: Minimal shift in milliseconds.
:type min_shift_ms: float
:param max_shift_ms: Maximal shift in milliseconds.
:type max_shift_ms: float
"""

def __init__(self, rng, min_shift_ms, max_shift_ms):
self._min_shift_ms = min_shift_ms
self._max_shift_ms = max_shift_ms
self._rng = rng

def transform_audio(self, audio_segment):
"""Shift audio.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms)
audio_segment.shift(shift_ms)
19 changes: 10 additions & 9 deletions deep_speech_2/data_utils/augmentor/speed_perturb.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,21 @@ class SpeedPerturbAugmentor(AugmentorBase):
:param rng: Random generator object.
:type rng: random.Random
:param min_speed_rate: Lower bound of new speed rate to sample.
:param min_speed_rate: Lower bound of new speed rate to sample and should
not below 0.9.
:type min_speed_rate: float
:param max_speed_rate: Upper bound of new speed rate to sample.
:param max_speed_rate: Upper bound of new speed rate to sample and should
not above 1.1.
:type max_speed_rate: float
"""

def __init__(self, rng, min_speed_rate, max_speed_rate):

if (min_speed_rate < 0.5):
raise ValueError("Sampling speed below 0.9 can cause unnatural "\
"effects")
if (max_speed_rate > 1.5):
raise ValueError("Sampling speed above 1.1 can cause unnatural "\
"effects")
if min_speed_rate < 0.9:
raise ValueError(
"Sampling speed below 0.9 can cause unnatural effects")
if max_speed_rate > 1.1:
raise ValueError(
"Sampling speed above 1.1 can cause unnatural effects")
self._min_speed_rate = min_speed_rate
self._max_speed_rate = max_speed_rate
self._rng = rng
Expand Down
2 changes: 1 addition & 1 deletion deep_speech_2/data_utils/augmentor/volume_perturb.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ def transform_audio(self, audio_segment):
:type audio_segment: AudioSegmenet|SpeechSegment
"""
gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
audio_segment.apply_gain(gain)
audio_segment.gain_db(gain)
Loading

0 comments on commit 9ec357e

Please sign in to comment.