Skip to content

Commit

Permalink
transformers.fx.symbolic_trace supports inputs_embeds (huggingface#31574)
Browse files Browse the repository at this point in the history

* symbolic trace supports inputs_embeds

* fix test?

* Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <[email protected]>

---------

Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
fxmarty and amyeroberts authored Jul 8, 2024
1 parent e5ca9b0 commit ba74370
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
7 changes: 7 additions & 0 deletions src/transformers/utils/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,13 @@ def _generate_dummy_input(
inputs_dict[input_name] = torch.zeros(
*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
)
elif "inputs_embeds" in input_name:
batch_size = shape[0]
sequence_length = shape[-1]

inputs_dict[input_name] = torch.zeros(
batch_size, sequence_length, model.config.hidden_size, dtype=torch.float, device=device
)
elif "visual_feats" in input_name:
inputs_dict[input_name] = torch.zeros(
shape
Expand Down
16 changes: 14 additions & 2 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,6 +1158,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
"input_features",
"input_ids",
"input_values",
"inputs_embeds",
"pixel_values",
"token_type_ids",
"visual_feats",
Expand Down Expand Up @@ -1214,16 +1215,27 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
(past_mask, inputs_to_test[1]["attention_mask"]), dim=1
)

if "inputs_embeds" in inspect.signature(model.forward).parameters:
inputs_to_test.append(
{
"inputs_embeds": torch.rand(
2, 2, model.config.hidden_size, dtype=torch.float, device=torch_device
)
}
)

for inps in inputs_to_test:
filtered_inputs = {k: v for (k, v) in inps.items() if k in input_names}
input_names = list(filtered_inputs.keys())
input_names_to_trace = list(filtered_inputs.keys())

if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
not hasattr(model.config, "problem_type") or model.config.problem_type is None
):
model.config.problem_type = "single_label_classification"

traced_model = symbolic_trace(model, input_names)
model.config.use_cache = "past_key_values" in input_names_to_trace

traced_model = symbolic_trace(model, input_names_to_trace)

with torch.no_grad():
traced_output = traced_model(**filtered_inputs)
Expand Down

0 comments on commit ba74370

Please sign in to comment.