Making TF Lxmert model compliant with AMP (huggingface#10257)
* Fix AMP

* Rework cast

* Apply style
jplu authored Feb 19, 2021
1 parent d27b28d commit 2fc6284
Showing 2 changed files with 16 additions and 15 deletions.
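
The change is mechanical but worth spelling out. Under TensorFlow mixed precision (AMP, i.e. the mixed_float16 Keras policy), activations such as the attention scores come out in float16, so any mask or constant hard-coded as tf.float32 raises a dtype mismatch as soon as it is combined with them. The diff below therefore derives every cast and constant from the dtype of the tensor it is applied to. A minimal sketch of that pattern, not taken from the repository:

import tensorflow as tf

# Illustration only: the additive-mask pattern the diff applies, with every
# constant following the dtype of the scores it is added to.
def apply_additive_mask(scores, mask):
    mask = tf.cast(mask, dtype=scores.dtype)                   # follow the compute dtype
    one_cst = tf.constant(1.0, dtype=scores.dtype)
    ten_thousand_cst = tf.constant(-10000.0, dtype=scores.dtype)
    return scores + tf.multiply(tf.subtract(one_cst, mask), ten_thousand_cst)

scores = tf.random.normal((2, 12, 8, 8), dtype=tf.float16)     # float16, as produced under AMP
mask = tf.ones((2, 1, 1, 8))                                   # float32 input mask
print(apply_additive_mask(scores, mask).dtype)                 # float16, no dtype error

Without the tf.cast on the first line of the helper, the mask arithmetic fails under AMP with a float16/float32 mismatch, which is the error the previously skipped mixed-precision test was hiding.
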
27 changes: 16 additions & 11 deletions src/transformers/models/lxmert/modeling_tf_lxmert.py
@@ -295,11 +295,12 @@ def call(self, hidden_states, context, attention_mask, output_attentions, traini
         attention_scores = tf.matmul(
             query_layer, key_layer, transpose_b=True
         )  # (batch size, num_heads, seq_len_q, seq_len_k)
-        dk = tf.cast(shape_list(key_layer)[-1], tf.float32)  # scale attention_scores
+        dk = tf.cast(shape_list(key_layer)[-1], dtype=attention_scores.dtype)  # scale attention_scores
         attention_scores = attention_scores / tf.math.sqrt(dk)

         if attention_mask is not None:
-            # Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
+            # Apply the attention mask is (precomputed for all layers in TFLxmertModel call() function)
+            attention_mask = tf.cast(attention_mask, dtype=attention_scores.dtype)
             attention_scores = attention_scores + attention_mask

         # Normalize the attention scores to probabilities.
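
The same rule applies to the scaling factor in the hunk above: dk now takes the dtype of attention_scores instead of a hard-coded tf.float32, so the division by sqrt(dk) stays in a single dtype in both full and half precision. A stand-alone illustration with invented shapes:

import tensorflow as tf

q = tf.random.normal((2, 4, 8, 16), dtype=tf.float16)   # (batch, heads, seq_len, head_dim), float16 under AMP
k = tf.random.normal((2, 4, 8, 16), dtype=tf.float16)

scores = tf.matmul(q, k, transpose_b=True)               # (batch, heads, seq_len_q, seq_len_k)
dk = tf.cast(tf.shape(k)[-1], dtype=scores.dtype)        # previously cast to tf.float32
scores = scores / tf.math.sqrt(dk)                       # no float32/float16 mix
print(scores.dtype)                                      # float16
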
@@ -721,6 +722,11 @@ def call(
         if inputs["token_type_ids"] is None:
             inputs["token_type_ids"] = tf.fill(input_shape, 0)

+        # Positional Word Embeddings
+        embedding_output = self.embeddings(
+            inputs["input_ids"], inputs["token_type_ids"], inputs["inputs_embeds"], training=inputs["training"]
+        )
+
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
@@ -734,8 +740,10 @@ def call(
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.

-        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
+        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

         if inputs["visual_attention_mask"] is not None:
             extended_visual_attention_mask = tf.reshape(
@@ -745,16 +753,13 @@ def call(
                 tf.expand_dims(inputs["visual_attention_mask"], axis=1), axis=1
             )

-            extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, tf.float32)
-            extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
+            extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, dtype=embedding_output.dtype)
+            extended_visual_attention_mask = tf.multiply(
+                tf.subtract(one_cst, extended_visual_attention_mask), ten_thousand_cst
+            )
         else:
             extended_visual_attention_mask = None

-        # Positional Word Embeddings
-        embedding_output = self.embeddings(
-            inputs["input_ids"], inputs["token_type_ids"], inputs["inputs_embeds"], training=inputs["training"]
-        )
-
         # Run Lxmert encoder
         encoder_outputs = self.encoder(
             embedding_output,
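
Note the reordering in this file: the embeddings call moves ahead of the mask construction so that embedding_output.dtype can serve as the reference dtype for one_cst and ten_thousand_cst. Under a mixed_float16 policy, Keras layers keep float32 variables but return float16 activations, which is why the output dtype, not the variable dtype, is the one to follow. A small demonstration of that behaviour, assuming TF 2.4+ where set_global_policy is available:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

layer = tf.keras.layers.Dense(8)
out = layer(tf.zeros((2, 4)))

print(out.dtype)           # float16 -> the dtype masks and constants must match
print(layer.kernel.dtype)  # float32 -> master weights stay in full precision

tf.keras.mixed_precision.set_global_policy("float32")  # reset so later code is unaffected
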
4 changes: 0 additions & 4 deletions tests/test_modeling_tf_lxmert.py
@@ -706,10 +706,6 @@ def test_saved_model_creation(self):
         # This test is too long (>30sec) and makes fail the CI
         pass

-    def test_mixed_precision(self):
-        # TODO JP: Make Lxmert float16 compliant
-        pass
-
     @slow
     def test_saved_model_creation_extended(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
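
With the override deleted above, the Lxmert TF test class falls back to the shared test_mixed_precision check from the common TF test suite, which roughly builds the model under a mixed_float16 policy and runs a forward pass. The sketch below only approximates that check; the input shapes are invented and it assumes transformers (with this commit) and TF 2.4+ are installed:

import tensorflow as tf
from transformers import LxmertConfig, TFLxmertModel

tf.keras.mixed_precision.set_global_policy("mixed_float16")

config = LxmertConfig()                  # default sizes, randomly initialized weights
model = TFLxmertModel(config)

inputs = {
    "input_ids": tf.constant([[101, 2054, 2003, 2023, 102]]),
    "visual_feats": tf.random.normal((1, 10, config.visual_feat_dim)),
    "visual_pos": tf.random.uniform((1, 10, config.visual_pos_dim)),
}
outputs = model(inputs)                  # raised a float16/float32 mismatch before this commit
print(outputs.language_output.dtype)     # float16 under the mixed policy

tf.keras.mixed_precision.set_global_policy("float32")
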
