Skip to content

Commit

Permalink
[Fbsync] Lint fix (pytorch#1726)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored Aug 26, 2021
1 parent 4915524 commit 560c082
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/pipeline_tacotron2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,4 +253,4 @@ python inference.py --checkpoint-path ${model_path} \
--input-text "Hello world!" \
--text-preprocessor english_characters \
--output-path "./outputs.wav"
```
```
11 changes: 7 additions & 4 deletions test/torchaudio_unittest/common_utils/rnnt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from torchaudio.functional import rnnt_loss


CPU_DEVICE = torch.device("cpu")


class _NumpyTransducer(torch.autograd.Function):
@staticmethod
def forward(
Expand Down Expand Up @@ -240,7 +243,7 @@ def get_basic_data(device):
def get_B1_T10_U3_D4_data(
random=False,
dtype=torch.float32,
device=torch.device("cpu"),
device=CPU_DEVICE,
):
B, T, U, D = 2, 10, 3, 4

Expand All @@ -263,7 +266,7 @@ def grad_hook(grad):
return data


def get_B1_T2_U3_D5_data(dtype=torch.float32, device=torch.device("cpu")):
def get_B1_T2_U3_D5_data(dtype=torch.float32, device=CPU_DEVICE):
logits = torch.tensor(
[
0.1,
Expand Down Expand Up @@ -360,7 +363,7 @@ def grad_hook(grad):
return data, ref_costs, ref_gradients


def get_B2_T4_U3_D3_data(dtype=torch.float32, device=torch.device("cpu")):
def get_B2_T4_U3_D3_data(dtype=torch.float32, device=CPU_DEVICE):
# Test from D21322854
logits = torch.tensor(
[
Expand Down Expand Up @@ -550,7 +553,7 @@ def get_random_data(
max_D=40,
blank=-1,
dtype=torch.float32,
device=torch.device("cpu"),
device=CPU_DEVICE,
seed=None,
):
if seed is not None:
Expand Down

0 comments on commit 560c082

Please sign in to comment.