Add fastpath test for mask check flag (pytorch#82999)
Summary: Check that the fastpath is taken, and which variant (sparsity fastpath or normal fastpath) is used, for a mask that is left aligned and for one that is not.

Test Plan: buck test caffe2/test:test_transformers

Differential Revision: D38259928

Pull Request resolved: pytorch#82999
Approved by: https://github.com/jbschlosser
Yoav Navon authored and pytorchmergebot committed Aug 12, 2022
1 parent b60dc2e commit dfc97df
Showing 1 changed file with 61 additions and 1 deletion.
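For context, the property the new test keys on is whether the key padding mask is left aligned: the padded (True) positions form a contiguous suffix of each row, so no real token follows a padded one. Below is a minimal sketch of that property, assuming True marks padding; the helper is_left_aligned is illustrative only and is not the check PyTorch itself performs.

import torch

def is_left_aligned(mask: torch.Tensor) -> bool:
    # A row is left aligned if, once a padded (True) position appears,
    # every later position in that row is also padded.
    m = mask.int()
    return bool((m == m.cummax(dim=1).values).all())

aligned_mask = torch.tensor([[False, False, True]])     # token, token, pad -> aligned
not_aligned_mask = torch.tensor([[True, False, True]])  # pad, token, pad   -> not aligned

print(is_left_aligned(aligned_mask))      # True  -> sparsity fastpath (nested tensors)
print(is_left_aligned(not_aligned_mask))  # False -> regular fastpath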
62 changes: 61 additions & 1 deletion test/test_transformers.py
@@ -5,10 +5,17 @@
import torch.nn as nn
import torch.nn.functional as F
import unittest
from unittest.mock import patch

from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
TEST_FAIRSEQ, run_tests, parametrize, instantiate_parametrized_tests, freeze_rng_state)
TEST_FAIRSEQ,
run_tests,
parametrize,
instantiate_parametrized_tests,
freeze_rng_state,
TEST_WITH_CROSSREF
)
from torch.testing._internal.common_cuda import TEST_CUDA

if TEST_FAIRSEQ:
@@ -724,6 +731,59 @@ def rand_tensor(*shape):
if dropout_p == 0.0 or device == 'cpu':
self.assertEqual(actual, expected)

@unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
@torch.no_grad()
def test_mask_check_fastpath(self):
"""
Test that the fastpath is executed regardless of the mask that is passed.
If the passed mask is left aligned or mask_check=False, check that nested tensors are used (sparsity fastpath);
otherwise the fastpath with regular tensors is used.
"""

x = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]).to(torch.float)

def _test_fastpath(model, mask, mock_return_value, nested_tensors=True):
with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock:
fastpath_mock.return_value = mock_return_value
model(x, src_key_padding_mask=mask)

# If mock was called, fastpath was taken
self.assertTrue(fastpath_mock.called)

# If mock was called with nested tensors, sparsity fastpath was taken
for call_args, _ in fastpath_mock.call_args_list:
self.assertEqual(call_args[0].is_nested, nested_tensors)

encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)

model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
model.eval()

aligned_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool)
not_aligned_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool)
nested_tensor_return_value = torch.nested_tensor([torch.ones((2, 2), dtype=torch.float)])
tensor_return_value = torch.ones((1, 3, 2), dtype=torch.float)

# Left aligned mask results in sparsity fastpath
_test_fastpath(model, aligned_mask, nested_tensor_return_value, nested_tensors=True)

# Not aligned mask results in fastpath
_test_fastpath(model, not_aligned_mask, tensor_return_value, nested_tensors=False)

model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False, mask_check=True)
model.eval()

# If nested tensor disabled, fastpath is always taken
_test_fastpath(model, aligned_mask, tensor_return_value, nested_tensors=False)
_test_fastpath(model, not_aligned_mask, tensor_return_value, nested_tensors=False)


model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=False)
model.eval()

# Mask check disabled results in sparsity fastpath, independently of the mask
_test_fastpath(model, aligned_mask, nested_tensor_return_value, nested_tensors=True)
_test_fastpath(model, not_aligned_mask, nested_tensor_return_value, nested_tensors=True)

# TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
# cross device / dtype testing.
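For a local run outside of Meta's buck setup, the same test can typically be invoked directly from a PyTorch source checkout (assuming a standard development install):

python test/test_transformers.py -k test_mask_check_fastpath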
