[Flax] Correct shift labels for seq2seq models in Flax (huggingface#12720)

* fix_torch_device_generate_test

* remove @

* push

* fix marian

* fix

* up
patrickvonplaten authored Jul 15, 2021
1 parent 1a3deae commit 8244c5a
Showing 4 changed files with 27 additions and 23 deletions.
src/transformers/models/bart/modeling_flax_bart.py: 7 additions & 5 deletions
@@ -19,6 +19,8 @@
 from functools import partial
 from typing import Callable, Optional, Tuple
 
+import numpy as np
+
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -212,15 +214,15 @@
 """
 
 
-def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
-    shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
-    # replace possible -100 values in labels by `pad_token_id`
-    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = np.zeros_like(input_ids)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1]
+    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids
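To make the new behavior concrete, here is a small standalone sketch; it is not part of the commit, and the token ids and the pad/start values (1 and 2) are made up for illustration:

    import numpy as np

    def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
        # same logic as the new numpy implementation above
        shifted_input_ids = np.zeros_like(input_ids)
        shifted_input_ids[:, 1:] = input_ids[:, :-1]
        shifted_input_ids[:, 0] = decoder_start_token_id
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
        return shifted_input_ids

    labels = np.array([[42, 43, 44, -100, -100]])  # -100 marks ignored label positions
    print(shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2))
    # [[ 2 42 43 44  1]]: the row starts with decoder_start_token_id,
    # the labels are shifted one position to the right, and -100 becomes pad_token_id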


src/transformers/models/marian/modeling_flax_marian.py: 4 additions & 4 deletions
@@ -221,11 +221,11 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
-    shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
-    # replace possible -100 values in labels by `pad_token_id`
-    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = np.zeros_like(input_ids)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1]
+    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids
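For background (this note and snippet are not from the commit): JAX arrays are immutable, which is why the removed implementations had to route every slice update through jax.ops.index_update, while plain numpy arrays accept in-place slice assignment directly, so the new code can shift with two assignments. A minimal illustration of the constraint:

    import jax.numpy as jnp

    x = jnp.zeros((1, 4), dtype=jnp.int32)
    try:
        x[:, 0] = 2  # numpy-style in-place update on a JAX array
    except TypeError as err:
        print(err)   # JAX arrays are immutable and do not support item assignment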


src/transformers/models/mbart/modeling_flax_mbart.py: 10 additions & 9 deletions
@@ -19,6 +19,8 @@
 from functools import partial
 from typing import Callable, Optional, Tuple
 
+import numpy as np
+
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -217,20 +219,19 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
     have a single `decoder_start_token_id` in contrast to other Bart-like models.
     """
-    prev_output_tokens = jnp.array(input_ids).clone()
+    prev_output_tokens = np.array(input_ids).copy()
 
     assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
 
     # replace possible -100 values in labels by `pad_token_id`
-    prev_output_tokens = jnp.where(prev_output_tokens == -100, pad_token_id, input_ids)
-    index_of_eos = (jnp.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
-    decoder_start_tokens = jnp.array(
-        [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)]
+    prev_output_tokens = np.where(prev_output_tokens == -100, pad_token_id, input_ids)
+    index_of_eos = (np.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
+    decoder_start_tokens = np.array(
+        [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)], dtype=np.int32
     ).squeeze()
-    # for loop basically does jax-compatible version of prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
-    for i in range(prev_output_tokens.shape[1], 0, -1):
-        prev_output_tokens = jax.ops.index_update(prev_output_tokens, (..., i), prev_output_tokens[:, i - 1])
-    prev_output_tokens = jax.ops.index_update(prev_output_tokens, (..., 0), decoder_start_tokens)
+
+    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].copy()
+    prev_output_tokens[:, 0] = decoder_start_tokens
 
     return prev_output_tokens
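To see the MBart-specific wrapping in action, here is a standalone sketch of the new numpy logic; it is not part of the commit, and the concrete ids (pad=1, </s>=2, language ids 250004 and 250020) are only illustrative:

    import numpy as np

    def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int) -> np.ndarray:
        # numpy logic from the diff above, without the assert
        prev_output_tokens = np.array(input_ids).copy()
        prev_output_tokens = np.where(prev_output_tokens == -100, pad_token_id, input_ids)
        index_of_eos = (np.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
        decoder_start_tokens = np.array(
            [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)], dtype=np.int32
        ).squeeze()
        prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].copy()
        prev_output_tokens[:, 0] = decoder_start_tokens
        return prev_output_tokens

    labels = np.array([
        [67, 68, 2, 250004, 1, 1],    # tokens </s> <LID> <pad> <pad>
        [90, 91, 92, 93, 2, 250020],  # tokens </s> <LID>, no padding
    ])
    print(shift_tokens_right(labels, pad_token_id=1))
    # row 0 -> [250004 67 68 2 250004 1]
    # row 1 -> [250020 90 91 92 93 2]
    # each row now begins with its own language id: the last non-pad token is
    # looked up per row and moved to position 0, which is exactly what the
    # removed for-loop emulated with jax.ops.index_update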

src/transformers/models/t5/modeling_flax_t5.py: 6 additions & 5 deletions
@@ -47,15 +47,16 @@
 _TOKENIZER_FOR_DOC = "T5Tokenizer"
 
 
-def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
+def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
-    shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
-    # replace possible -100 values in labels by `pad_token_id`
-    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = np.zeros_like(input_ids)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1]
+    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids


