Skip to content

Commit

Permalink
🚨🚨🚨 TF: Remove TFWrappedEmbeddings (breaking: TF embedding initiali…
Browse files Browse the repository at this point in the history
…zation updated for encoder-decoder models) (huggingface#19263)

* added test

* correct embedding init

* some changes in blenderbot (incomplete)

* update blenderbot (diff to be used as reference)

* update blenderbot_small

* update LED

* update marian

* update T5 and remove TFWrappedEmbeddings

* nullcontext() -> ContextManagers()

* fix embedding init
  • Loading branch information
gante authored Oct 11, 2022
1 parent 8e4ee28 commit 462cd64
Show file tree
Hide file tree
Showing 18 changed files with 515 additions and 1,032 deletions.
33 changes: 0 additions & 33 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3038,36 +3038,3 @@ def get_initializer(initializer_range: float = 0.02) -> tf.initializers.Truncate
`tf.initializers.TruncatedNormal`: The truncated normal initializer.
"""
return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)


class TFWrappedEmbeddings:
"""
this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer' class to avoid problem with
weight restoring. Also it makes sure that the layer is called from the correct scope to avoid problem with
saving/storing the correct weights
"""

# TODO (joao): flagged for delection due to embeddings refactor

def __init__(self, layer, abs_scope_name=None):
self._layer = layer
self._abs_scope_name = abs_scope_name
self.vocab_size = self._layer.vocab_size

def call(self, inputs, mode="embedding"):
if self._abs_scope_name is None:
return self._layer.call(inputs, mode)

# if an abs scope name is given to the embedding variable, call variable from absolute scope
with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
with tf.name_scope(abs_scope_name.original_name_scope):
return self._layer.call(inputs, mode)

def __call__(self, inputs, mode="embedding"):
if self._abs_scope_name is None:
return self._layer(inputs, mode)

# if an abs scope name is given to the embedding variable, call variable from absolute scope
with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
with tf.name_scope(abs_scope_name.original_name_scope):
return self._layer(inputs, mode)
18 changes: 7 additions & 11 deletions src/transformers/models/bart/modeling_tf_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


import random
from contextlib import nullcontext
from typing import Optional, Tuple, Union

import numpy as np
Expand All @@ -41,6 +40,7 @@
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ContextManagers,
add_code_sample_docstrings,
add_end_docstrings,
add_start_docstrings,
Expand Down Expand Up @@ -741,11 +741,10 @@ def call(
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context_manager = tf.name_scope(self.embed_tokens.load_weight_prefix + "/")
else:
context_manager = nullcontext()
with context_manager:
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
Expand Down Expand Up @@ -945,11 +944,10 @@ def call(
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context_manager = tf.name_scope(self.embed_tokens.load_weight_prefix + "/")
else:
context_manager = nullcontext()
with context_manager:
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
Expand Down Expand Up @@ -1378,8 +1376,6 @@ def call(
return_dict=return_dict,
training=training,
)
# TODO (joao): the line below is for models with tied embeddings. The previous TFBart had tied embeddings.
# The PT Bart does not have tied embeddings. Untie the weights while keeping loading retrocompatibility.
lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
lm_logits = self.bias_layer(lm_logits)
masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
Expand Down
116 changes: 63 additions & 53 deletions src/transformers/models/blenderbot/modeling_tf_blenderbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,12 @@
DUMMY_INPUTS,
TFCausalLanguageModelingLoss,
TFPreTrainedModel,
TFSharedEmbeddings,
TFWrappedEmbeddings,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ContextManagers,
add_code_sample_docstrings,
add_end_docstrings,
add_start_docstrings,
Expand Down Expand Up @@ -119,7 +118,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
return (one_cst - expanded_mask) * LARGE_NEGATIVE


class TFBlenderbotLearnedPositionalEmbedding(TFSharedEmbeddings):
class TFBlenderbotLearnedPositionalEmbedding(tf.keras.layers.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
Expand All @@ -133,8 +132,10 @@ def call(
"""Input is expected to be of size [bsz x seqlen]."""
if position_ids is None:
seq_len = input_shape[1]
position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return super().call(position_ids)
position_ids = tf.range(seq_len, delta=1, name="range")
position_ids += past_key_values_length

return super().call(tf.cast(position_ids, dtype=tf.int32))


# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Blenderbot
Expand Down Expand Up @@ -638,7 +639,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
config: BlenderbotConfig
"""

def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.dropout = tf.keras.layers.Dropout(config.dropout)
Expand Down Expand Up @@ -726,17 +727,25 @@ def call(
raise ValueError("You have to specify either input_ids or inputs_embeds")

if inputs_embeds is None:
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
),
)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
),
)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
Expand Down Expand Up @@ -805,7 +814,7 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
embed_tokens: output embedding
"""

def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.padding_idx = config.pad_token_id
Expand Down Expand Up @@ -933,17 +942,21 @@ def call(
positions = self.embed_positions(input_shape, position_ids=position_ids)

if inputs_embeds is None:
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
),
)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
),
)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

hidden_states = inputs_embeds

Expand Down Expand Up @@ -1037,32 +1050,25 @@ def __init__(self, config: BlenderbotConfig, **kwargs):
super().__init__(**kwargs)

self.config = config
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")

with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass

# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
embed_tokens.vocab_size = self.shared.vocab_size
embed_tokens.hidden_size = self.shared.hidden_size
self.shared = tf.keras.layers.Embedding(
input_dim=config.vocab_size,
output_dim=config.d_model,
embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=self.config.init_std),
name="model.shared",
)
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
self.shared.load_weight_prefix = "model.shared"

self.encoder = TFBlenderbotEncoder(config, embed_tokens, name="encoder")
self.decoder = TFBlenderbotDecoder(config, embed_tokens, name="decoder")
self.encoder = TFBlenderbotEncoder(config, self.shared, name="encoder")
self.decoder = TFBlenderbotDecoder(config, self.shared, name="decoder")

def get_input_embeddings(self):
return self.shared

def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
self.shared = new_embeddings
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared

@unpack_inputs
def call(
Expand Down Expand Up @@ -1284,7 +1290,6 @@ def __init__(self, config, *inputs, **kwargs):
self.bias_layer = BiasLayer(
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
)
self.final_logits_bias = self.bias_layer.bias # alias to keep the same interface with PT

def get_decoder(self):
return self.model.decoder
Expand All @@ -1299,10 +1304,15 @@ def set_output_embeddings(self, value):
self.set_input_embeddings(value)

def get_bias(self):
return {"final_logits_bias": self.final_logits_bias}
return {"final_logits_bias": self.bias_layer.bias}

def set_bias(self, value):
self.final_logits_bias = value["final_logits_bias"]
# Replaces the existing layers containing bias for correct (de)serialization.
vocab_size = value["final_logits_bias"].shape[-1]
self.bias_layer = BiasLayer(
name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
)
self.bias_layer.bias.assign(value["final_logits_bias"])

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
Expand Down Expand Up @@ -1385,7 +1395,7 @@ def call(
return_dict=return_dict,
training=training,
)
lm_logits = self.model.shared(outputs[0], mode="linear")
lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
lm_logits = self.bias_layer(lm_logits)
masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

Expand Down
Loading

0 comments on commit 462cd64

Please sign in to comment.