Skip to content

Commit

Permalink
fix cast type in _se_pool_step_script_infer and _se_pool_step_script_…
Browse files Browse the repository at this point in the history
…train (NVIDIA#3239)

Signed-off-by: Oktai Tatanov <[email protected]>
  • Loading branch information
Oktai15 authored Nov 24, 2021
1 parent 6d3b257 commit 0c0afd4
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions nemo/collections/asr/parts/submodules/jasper.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,12 @@ def _se_pool_step_script_infer(x: torch.Tensor, context_window: int, mask: torch
"""
timesteps = x.shape[-1]
if timesteps < context_window:
y = torch.sum(x, dim=-1, keepdim=True) / mask.sum(dim=-1, keepdim=True).type(x.dtype)
y = torch.sum(x, dim=-1, keepdim=True) / mask.sum(dim=-1, keepdim=True).to(x.dtype)
else:
x = x[:, :, :context_window] # [B, C, context_window]
mask = mask[:, :, :context_window] # [B, 1, context_window]

mask = mask.sum(dim=-1, keepdim=True).type(x.dtype) # [B, C, 1]
mask = mask.sum(dim=-1, keepdim=True).to(x.dtype) # [B, C, 1]
y = x.sum(dim=-1, keepdim=True) # [B, 1, 1]
y = y / (mask + 1e-8) # [B, C, 1]

Expand All @@ -207,13 +207,13 @@ def _se_pool_step_script_train(x: torch.Tensor, context_window: int, mask: torch
"""
timesteps = x.shape[-1]
if timesteps < context_window:
y = torch.sum(x, dim=-1, keepdim=True) / mask.sum(dim=-1, keepdim=True).type(x.dtype)
y = torch.sum(x, dim=-1, keepdim=True) / mask.sum(dim=-1, keepdim=True).to(x.dtype)
else:
start_idx = torch.randint(0, timesteps - context_window, size=[1], dtype=torch.int32)[0]
x = x[:, :, start_idx : (start_idx + context_window)] # [B, C, context_window]
mask = mask[:, :, start_idx : (start_idx + context_window)] # [B, 1, context_window]

mask = mask.sum(dim=-1, keepdim=True).type(x.dtype) # [B, C, 1]
mask = mask.sum(dim=-1, keepdim=True).to(x.dtype) # [B, C, 1]
y = x.sum(dim=-1, keepdim=True) # [B, 1, 1]
y = y / (mask + 1e-8) # [B, C, 1]

Expand Down

0 comments on commit 0c0afd4

Please sign in to comment.