Skip to content

Commit

Permalink
Merge: [FastPitch/PyT] Fix pitch-modification notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-kkudrynski committed Mar 1, 2022
2 parents 5bd2dee + 1083538 commit 20ab781
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import torch


Expand All @@ -11,7 +10,7 @@ def pitch_transform_custom(pitch, pitch_lens):
PARAMS
------
pitch: torch.Tensor (bs, max_len)
pitch: torch.Tensor (bs, 1, max_len)
Predicted pitch values for each lexical unit, padded to max_len (in Hz).
pitch_lens: torch.Tensor (bs, max_len)
Number of lexical units in each utterance.
Expand All @@ -22,7 +21,7 @@ def pitch_transform_custom(pitch, pitch_lens):
Modified pitch (in Hz).
"""

weights = torch.arange(pitch.size(1), dtype=torch.float32, device=pitch.device)
weights = torch.arange(pitch.size(2), dtype=torch.float32, device=pitch.device)

# The weights increase linearly from 0.0 to 1.0 in every i-th row
# in the range (0, pitch_lens[i])
Expand All @@ -31,4 +30,4 @@ def pitch_transform_custom(pitch, pitch_lens):
# Shift the range from (0.0, 1.0) to (0.5, 1.5)
weights += 0.5

return pitch * weights
return pitch * weights.unsqueeze(1)
2 changes: 1 addition & 1 deletion PyTorch/SpeechSynthesis/FastPitch/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
unique_log_fpath)
from common.text import cmudict
from common.text.text_processing import TextProcessing
from pitch_transform import pitch_transform_custom
from fastpitch.pitch_transform import pitch_transform_custom
from waveglow import model as glow
from waveglow.denoiser import Denoiser

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,15 @@
"outputs": [],
"source": [
"! mkdir -p output\n",
"! MODEL_DIR='../pretrained_models' ../scripts/download_fastpitch.sh\n",
"! MODEL_DIR='../pretrained_models' ../scripts/download_waveglow.sh"
"\n",
"# Download grapheme-level model which will be easier to manipulate\n",
"! MODEL_ZIP=\"nvidia_fastpitch_200518.zip\" \\\n",
" MODEL=\"nvidia_fastpitch_200518.pt\" \\\n",
" MODEL_URL=\"https://api.ngc.nvidia.com/v2/models/nvidia/fastpitch_pyt_amp_ckpt_v1/versions/20.02.0/zip\" \\\n",
" MODEL_DIR='../pretrained_models/fastpitch' \\\n",
" ../scripts/download_fastpitch.sh\n",
"\n",
"! MODEL_DIR='../pretrained_models/waveglow' ../scripts/download_waveglow.sh"
]
},
{
Expand All @@ -123,8 +130,8 @@
"\n",
"# store paths in aux variables\n",
"fastp = '../pretrained_models/fastpitch/nvidia_fastpitch_200518.pt'\n",
"waveg = '../pretrained_models/waveglow/waveglow_1076430_14000_amp.pt'\n",
"flags = f'--cuda --fastpitch {fastp} --waveglow {waveg} --wn-channels 256'"
"waveg = '../pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt'\n",
"flags = f'--cuda --fastpitch {fastp} --waveglow {waveg} --wn-channels 256 --p-arpabet 0.0'"
]
},
{
Expand Down Expand Up @@ -176,7 +183,7 @@
"metadata": {},
"outputs": [],
"source": [
"%%writefile ../pitch_transform.py\n",
"%%writefile ../fastpitch/pitch_transform.py\n",
"import torch\n",
"import numpy as np\n",
"\n",
Expand Down Expand Up @@ -216,7 +223,7 @@
},
"outputs": [],
"source": [
"# Synthesis with pace 0.75 and odd - even sentence transformation\n",
"# Synthesis with pace 0.75 and odd-even sentence transformation\n",
"!python ../inference.py {flags} -i text.txt -o output/custom --pitch-transform-custom --pace 0.75 > /dev/null\n",
"\n",
"IPython.display.Audio(\"output/custom/audio_0.wav\")"
Expand Down Expand Up @@ -257,7 +264,7 @@
"metadata": {},
"outputs": [],
"source": [
"%%writefile ../pitch_transform.py\n",
"%%writefile ../fastpitch/pitch_transform.py\n",
"import torch\n",
"\n",
"def pitch_transform_custom(pitch, pitch_lens):\n",
Expand All @@ -266,7 +273,7 @@
" \n",
" # Put emphasis on `lly?` in 'Really?'\n",
" for i in range(len('Rea'), len('Really?')):\n",
" pitch[0][i] = 280 + (i - 3) * 20\n",
" pitch[0][0, i] = 280 + (i - 3) * 20\n",
"\n",
" return pitch"
]
Expand All @@ -290,7 +297,7 @@
"metadata": {},
"outputs": [],
"source": [
"%%writefile ../pitch_transform.py\n",
"%%writefile ../fastpitch/pitch_transform.py\n",
"import torch\n",
"\n",
"def pitch_transform_custom(pitch, pitch_lens):\n",
Expand All @@ -299,7 +306,7 @@
" \n",
" # Fixed 'really' word adjustment\n",
" for i in range(len('Really?')):\n",
" pitch[0][i] = 215 - i * 10\n",
" pitch[0][0, i] = 215 - i * 10\n",
"\n",
" return pitch * torch.tensor(0.8)"
]
Expand Down Expand Up @@ -352,17 +359,17 @@
"metadata": {},
"outputs": [],
"source": [
"%%writefile ../pitch_transform.py\n",
"%%writefile ../fastpitch/pitch_transform.py\n",
"import torch\n",
"\n",
"def pitch_transform_custom(pitch, pitch_lens):\n",
" \n",
" pitch[0][-6] = 180 # R\n",
" pitch[0][-5] = 260 # i\n",
" pitch[0][-4] = 360 # g\n",
" pitch[0][-3] = 360 # h\n",
" pitch[0][-2] = 380 # t\n",
" pitch[0][-1] = 400 # ?\n",
" pitch[0][0, -6] = 180 # R\n",
" pitch[0][0, -5] = 260 # i\n",
" pitch[0][0, -4] = 360 # g\n",
" pitch[0][0, -3] = 360 # h\n",
" pitch[0][0, -2] = 380 # t\n",
" pitch[0][0, -1] = 400 # ?\n",
"\n",
" return pitch * torch.tensor(0.9)"
]
Expand All @@ -383,7 +390,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -397,7 +404,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
"version": "3.8.12"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@

set -e

: ${CMUDICT_DIR:="cmudict"}

echo "Downloading cmudict-0.7b ..."
wget https://github.com/Alexir/CMUdict/raw/master/cmudict-0.7b -qO cmudict/cmudict-0.7b
wget https://github.com/Alexir/CMUdict/raw/master/cmudict-0.7b -qO $CMUDICT_DIR/cmudict-0.7b
13 changes: 10 additions & 3 deletions PyTorch/SpeechSynthesis/FastPitch/scripts/download_fastpitch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@

set -e

# # Grapheme-level w/o energy conditioning
# MODEL_ZIP="nvidia_fastpitch_200518.zip"
# MODEL="nvidia_fastpitch_200518.pt"
# MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/fastpitch_pyt_amp_ckpt_v1/versions/20.02.0/zip"
# MODEL_DIR="../pretrained_models/fastpitch"

# Phoneme-level w/ energy conditioning
: ${MODEL_DIR:="pretrained_models/fastpitch"}
MODEL_ZIP="nvidia_fastpitch_210824.zip"
MODEL="nvidia_fastpitch_210824.pt"
MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/fastpitch_pyt_amp_ckpt_v1_1/versions/21.05.0/zip"
: ${MODEL_ZIP:="nvidia_fastpitch_210824.zip"}
: ${MODEL:="nvidia_fastpitch_210824.pt"}
: ${MODEL_URL:="https://api.ngc.nvidia.com/v2/models/nvidia/fastpitch_pyt_amp_ckpt_v1_1/versions/21.05.0/zip"}

mkdir -p "$MODEL_DIR"

Expand Down

0 comments on commit 20ab781

Please sign in to comment.