symbolic_trace: add past_key_values, llama, sdpa support (huggingface#28447)

* torch.fx: add pkv, llama, sdpa support

* Update src/transformers/models/opt/modeling_opt.py

* remove spaces

* trigger ci

* use explicit variable names
fxmarty authored Jan 17, 2024
1 parent 09eb11a commit a6adc05
Showing 4 changed files with 68 additions and 11 deletions.
17 changes: 10 additions & 7 deletions src/transformers/modeling_attn_mask_utils.py
@@ -132,6 +132,7 @@ def to_4d(
expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
attention_mask_2d.device
)

if causal_4d_mask is not None:
expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)

@@ -346,10 +347,10 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
key_value_length = input_shape[-1] + past_key_values_length
batch_size, query_length = input_shape

# torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
# TODO: Fix this as well when using torchdynamo with fullgraph=True.
is_tracing = torch.jit.is_tracing()
is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy)

if attention_mask is not None:
# 4d mask is passed through
@@ -367,10 +368,8 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
)
return attention_mask

elif torch.all(attention_mask == 1):
if is_tracing:
pass
elif query_length == 1:
elif not is_tracing and torch.all(attention_mask == 1):
if query_length == 1:
# For query_length == 1, causal attention and bi-directional attention are the same.
attention_mask = None
elif key_value_length == query_length:
@@ -405,7 +404,11 @@ def _prepare_4d_causal_attention_mask_for_sdpa(

# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
if query_length > 1:
#
# This fix is not applied in case we are tracing with torch.jit.trace or symbolic_trace, as _unmask_unattended has a data-dependent
# controlflow that can not be captured properly.
# TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case.
if query_length > 1 and not is_tracing:
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
expanded_4d_mask, attention_mask, unmasked_value=0.0
)
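A minimal standalone sketch (not part of this diff) of why the new `is_tracing` check works: during torch.fx symbolic tracing the forward pass receives torch.fx.Proxy objects instead of real tensors, so data-dependent branches such as `torch.all(attention_mask == 1)` can be skipped at trace time. The module and names below are illustrative only.

import torch
import torch.fx


class Probe(torch.nn.Module):
    def forward(self, attention_mask):
        # During torch.fx symbolic tracing, `attention_mask` is a torch.fx.Proxy,
        # so the data-dependent shortcut below is skipped while tracing.
        is_tracing = torch.jit.is_tracing() or isinstance(attention_mask, torch.fx.Proxy)
        if not is_tracing and torch.all(attention_mask == 1):
            return None  # eager-only shortcut, cannot be captured in a graph
        return attention_mask  # trace-friendly path: always return an explicit mask


traced = torch.fx.symbolic_trace(Probe())
print(traced.code)  # the captured graph contains only the explicit-mask path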
29 changes: 29 additions & 0 deletions src/transformers/utils/fx.py
@@ -131,6 +131,7 @@ def _generate_supported_model_class_names(
"gptj",
"hubert",
"layoutlm",
"llama",
"lxmert",
"m2m_100",
"marian",
@@ -156,6 +157,8 @@
# "xlnet",
]

_FX_SUPPORTED_MODELS_WITH_KV_CACHE = ["llama", "opt"]

_REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
if isinstance(item, dict):
@@ -514,6 +517,14 @@ def torch_nn_functional_one_hot(tensor, num_classes=-1):
return torch.empty(shape, device="meta")


def torch_nn_functional_scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
):
target_length = query.shape[-2]
head_dim = value.shape[-1]
return torch.empty((*query.shape[:-2], target_length, head_dim), device="meta")


def torch_nn_mseloss(self, input, target):
if self.reduction == "none":
shape = target.shape
@@ -597,6 +608,7 @@ def to_concrete(t):
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
torch.unique_consecutive: torch_unique_consecutive,
torch.nn.functional.one_hot: torch_nn_functional_one_hot,
torch.nn.functional.scaled_dot_product_attention: torch_nn_functional_scaled_dot_product_attention,
torch.nn.MSELoss: torch_nn_mseloss,
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
@@ -868,6 +880,23 @@ def _generate_dummy_input(
inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
elif "mask" in input_name or "ids" in input_name:
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
elif "past_key_values" in input_name:
if model.config.model_type not in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
raise NotImplementedError(
f"Symbolic trace with past_key_values input is not supported yet for the model {model.config.model_type}. Please open an issue or a PR in Transformers repository if you would like to see the support added."
)
num_heads = model.config.num_attention_heads
head_dim = model.config.hidden_size // model.config.num_attention_heads

cache_shape = (shape[0], num_heads, 0, head_dim)
pkv = tuple(
(
torch.rand(cache_shape, dtype=torch.float, device=device),
torch.rand(cache_shape, dtype=torch.float, device=device),
)
for i in range(model.config.num_hidden_layers)
)
inputs_dict[input_name] = pkv
else:
shape_with_hidden_size = shape + [model.config.hidden_size]
inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device)
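With "llama" registered above and past_key_values accepted as a dummy input, transformers.utils.fx.symbolic_trace can trace a Llama model together with a KV cache. A hedged usage sketch follows; the toy configuration sizes are illustrative and not taken from the diff.

import torch
from transformers import LlamaConfig, LlamaForCausalLM
from transformers.utils.fx import symbolic_trace

# Toy configuration, chosen only to keep the example small.
config = LlamaConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
)
model = LlamaForCausalLM(config).eval()

# past_key_values is now an accepted input name; symbolic_trace generates the
# zero-length dummy cache shown in _generate_dummy_input above.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "past_key_values"])

# The traced GraphModule can then be called with real tensors and a legacy-format cache.
input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones(1, 8, dtype=torch.long)
head_dim = config.hidden_size // config.num_attention_heads
empty_cache = tuple(
    (
        torch.zeros(1, config.num_attention_heads, 0, head_dim),
        torch.zeros(1, config.num_attention_heads, 0, head_dim),
    )
    for _ in range(config.num_hidden_layers)
)
with torch.no_grad():
    outputs = traced(input_ids=input_ids, attention_mask=attention_mask, past_key_values=empty_cache)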
1 change: 1 addition & 0 deletions tests/models/llama/test_modeling_llama.py
@@ -292,6 +292,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
)
test_headmasking = False
test_pruning = False
fx_compatible = True

def setUp(self):
self.model_tester = LlamaModelTester(self)
32 changes: 28 additions & 4 deletions tests/test_modeling_common.py
@@ -118,7 +118,7 @@
)

if is_torch_fx_available():
from transformers.utils.fx import symbolic_trace
from transformers.utils.fx import _FX_SUPPORTED_MODELS_WITH_KV_CACHE, symbolic_trace


def _config_zero_init(config):
@@ -1004,7 +1004,9 @@ def test_torch_fx_output_loss(self):

def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
if not is_torch_fx_available() or not self.fx_compatible:
return
self.skipTest(
f"Either torch.fx is not available, or the model type {config.model_type} is not compatible with torch.fx"
)

configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.return_dict = False
@@ -1060,6 +1062,26 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
if end_positions is not None:
input_names.append("end_positions")

if model.config.model_type in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
input_names.append("past_key_values")

# Generally, model_tester.prepare_config_and_inputs_for_common does not seem to generate past key values inputs.
if "past_key_values" not in inputs:
batch_size = inputs[next(iter(inputs))].shape[0]
num_heads = model.config.num_attention_heads
head_dim = model.config.hidden_size // model.config.num_attention_heads

cache_shape = (batch_size, num_heads, 0, head_dim)
pkv = tuple(
(
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
)
for i in range(model.config.num_hidden_layers)
)

inputs["past_key_values"] = pkv

filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())

@@ -1069,8 +1091,10 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
model.config.problem_type = "single_label_classification"

traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
model_output = model(**filtered_inputs)

with torch.no_grad():
traced_output = traced_model(**filtered_inputs)
model_output = model(**filtered_inputs)

except Exception as e:
self.fail(f"Couldn't trace module: {e}")
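A short PyTorch-only sketch (not part of the test) of why a past length of 0 makes a safe dummy cache: concatenating an empty (batch, num_heads, 0, head_dim) key/value onto the freshly computed one is a no-op, so the attention output is unchanged while the with-cache code path is still exercised during tracing.

import torch
import torch.nn.functional as F

batch, num_heads, q_len, head_dim = 2, 4, 5, 16
query = torch.randn(batch, num_heads, q_len, head_dim)
key = torch.randn(batch, num_heads, q_len, head_dim)
value = torch.randn(batch, num_heads, q_len, head_dim)

# Zero-length "past" entries, mirroring cache_shape = (batch, num_heads, 0, head_dim) above.
past_key = torch.zeros(batch, num_heads, 0, head_dim)
past_value = torch.zeros(batch, num_heads, 0, head_dim)

key_with_cache = torch.cat([past_key, key], dim=2)        # identical to `key`
value_with_cache = torch.cat([past_value, value], dim=2)  # identical to `value`

out_no_cache = F.scaled_dot_product_attention(query, key, value, is_causal=True)
out_with_cache = F.scaled_dot_product_attention(query, key_with_cache, value_with_cache, is_causal=True)
torch.testing.assert_close(out_no_cache, out_with_cache)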
