Swap to new text encoder model
VikParuchuri committed Aug 14, 2024
1 parent 76204b9 commit 363282a
Showing 9 changed files with 225 additions and 70 deletions.
2 changes: 2 additions & 0 deletions surya/input/processing.py
@@ -84,6 +84,8 @@ def slice_bboxes_from_image(image: Image.Image, bboxes):
lines = []
for bbox in bboxes:
line = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
if line.size[0] == 0:
print(f"Warning: found an empty line with bbox {bbox}")
lines.append(line)
return lines
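
A minimal sketch of the degenerate case the new warning above flags, assuming only Pillow; the bbox values are made up for illustration.

from PIL import Image

# A bbox whose x-coordinates coincide crops to a zero-width image,
# which is exactly what the new check in slice_bboxes_from_image reports.
img = Image.new("RGB", (100, 50))
empty_line = img.crop((10, 5, 10, 20))  # x1 == x2 -> width 0
print(empty_line.size)  # (0, 15)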

87 changes: 86 additions & 1 deletion surya/model/recognition/config.py
@@ -39,7 +39,7 @@ class DonutSwinConfig(PretrainedConfig):

def __init__(
self,
- image_size=(196, 896),
+ image_size=(256, 896),
patch_size=4,
num_channels=3,
embed_dim=128,
@@ -117,6 +117,8 @@ def __init__(
init_std=0.02,
tie_word_embeddings=False,
aux_heads=0, # How many n-token-ahead heads to add
encoder_hidden_size=1024,
causal=False,
**kwargs,
):
self.num_hidden_layers = num_hidden_layers
@@ -147,6 +149,8 @@ def __init__(
self.init_std = init_std
self.tie_word_embeddings = tie_word_embeddings
self.aux_heads = aux_heads
self.encoder_hidden_size = encoder_hidden_size
self.causal = causal

super().__init__(
pad_token_id=pad_token_id,
@@ -160,6 +164,87 @@ def layers_block_type(self):
return (self.block_types * 100)[: self.num_hidden_layers]


class SuryaOCRTextEncoderConfig(PretrainedConfig):
model_type = "surya_ocr"

def __init__(
self,
num_hidden_layers=10,
vocab_size=65792,
hidden_size=1024,
intermediate_size=4 * 1024,
num_attention_heads=16,
lru_width=None,
attention_window_size=16,
conv1d_width=4,
logits_soft_cap=30.0,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=1,
hidden_activation="gelu_pytorch_tanh",
rope_theta=10000.0,
block_types=("attention",),
cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
self_attn_layers=(0, 1, 3, 5, 7, 9),
global_attn_layers=(0, 1, 3, 5, 7, 9),
attention_dropout=0.0,
num_key_value_heads=2,
attention_bias=False,
w_init_variance_scale=0.01,
init_std=0.02,
tie_word_embeddings=False,
aux_heads=0, # How many n-token-ahead heads to add
encoder_hidden_size=1024,
iteration_count=1,
causal=False,
query_token_count=128,
**kwargs,
):
self.num_hidden_layers = num_hidden_layers
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.lru_width = lru_width if lru_width is not None else hidden_size
self.attention_window_size = attention_window_size
self.conv1d_width = conv1d_width
self.logits_soft_cap = logits_soft_cap
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.block_types = list(block_types)
self.hidden_activation = hidden_activation
self.head_dim = self.hidden_size // self.num_attention_heads
self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
if self.num_key_value_heads > self.num_attention_heads:
raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")
self.cross_attn_layers = cross_attn_layers
self.self_attn_layers = self_attn_layers
self.global_attn_layers = global_attn_layers
self.attention_dropout = attention_dropout
self.attention_bias = attention_bias
self.w_init_variance_scale = w_init_variance_scale
self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
self.init_std = init_std
self.tie_word_embeddings = tie_word_embeddings
self.aux_heads = aux_heads
self.encoder_hidden_size = encoder_hidden_size
self.iteration_count = iteration_count
self.causal = causal
self.query_token_count = query_token_count

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs,
)

@property
def layers_block_type(self):
return (self.block_types * 100)[: self.num_hidden_layers]

TOTAL_TOKENS = 65536
TOKEN_OFFSET = 3 # Pad, eos, bos
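
A hedged sketch of how the new SuryaOCRTextEncoderConfig resolves its derived fields, assuming the surya package at this commit is importable; values follow the defaults defined above.

from surya.model.recognition.config import SuryaOCRTextEncoderConfig

cfg = SuryaOCRTextEncoderConfig()
print(cfg.head_dim)                     # 1024 // 16 = 64
print(cfg.query_token_count)            # 128
print(cfg.final_w_init_variance_scale)  # 2.0 / 10 = 0.2
print(cfg.layers_block_type)            # ['attention'] repeated for all 10 hidden layers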
72 changes: 67 additions & 5 deletions surya/model/recognition/decoder.py
@@ -6,7 +6,7 @@
from torch import nn
from transformers.utils import ModelOutput

from surya.model.recognition.config import SuryaOCRDecoderConfig
from surya.model.recognition.config import SuryaOCRDecoderConfig, SuryaOCRTextEncoderConfig
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -134,8 +134,8 @@ def __init__(self, config: SuryaOCRDecoderConfig):
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
self.rotary_emb = SuryaOCRDecoderRotaryEmbedding(
self.head_dim,
@@ -148,6 +148,7 @@ def forward(
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Encoder attention mask currently ignored

@@ -162,7 +163,8 @@ def forward(
value_states = self.v_proj(encoder_hidden_states)
key_states = key_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- self._update_cache(key_states, value_states)
+ if use_cache:
+     self._update_cache(key_states, value_states)
else:
key_states = self.key_states
value_states = self.value_states
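
A self-contained sketch of the shape change above: cross-attention queries are still projected from the decoder's own hidden size, while keys and values are now projected from the vision encoder's hidden size (encoder_hidden_size). Apart from encoder_hidden_size=1024, the sizes below are illustrative, not the shipped configuration.

import torch
from torch import nn

decoder_hidden_size, encoder_hidden_size = 512, 1024   # only 1024 comes from the config above
num_heads, num_kv_heads, head_dim = 8, 2, 64

q_proj = nn.Linear(decoder_hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(encoder_hidden_size, num_kv_heads * head_dim, bias=False)  # was decoder_hidden_size before this commit
v_proj = nn.Linear(encoder_hidden_size, num_kv_heads * head_dim, bias=False)

decoder_tokens = torch.randn(1, 8, decoder_hidden_size)    # decoder activations
image_features = torch.randn(1, 196, encoder_hidden_size)  # vision encoder output
q = q_proj(decoder_tokens)
k, v = k_proj(image_features), v_proj(image_features)
print(q.shape, k.shape, v.shape)  # (1, 8, 512), (1, 196, 128), (1, 196, 128)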
@@ -232,6 +234,7 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

# Final is bsz, num_attention_heads, seq_len, head_dim
query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -368,7 +371,7 @@ def forward(
# Do cross-attention on encoder outputs
cross_attn_inputs = self.cross_pre_norm(activations)
cross_attn_path = self.cross_attn_block(
- cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask
+ cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache
)
cross_attn_output = cross_attn_path + raw_activations
else:
@@ -448,6 +451,7 @@ def __init__(self, config: SuryaOCRDecoderConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.causal = config.causal

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
@@ -533,6 +537,9 @@ def forward(
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
# Ignore copy
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
if not self.causal:
return None

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
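
The early return above makes the causal mask optional: with causal=False (the default added in the new configs) no mask is built and attention is bidirectional. A hedged illustration with a hypothetical helper, not the method above:

import torch

def build_causal_mask(seq_len: int, causal: bool, dtype=torch.float32):
    # Hypothetical stand-in for the core behavior of _update_causal_mask.
    if not causal:
        return None  # bidirectional attention
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((seq_len, seq_len), min_dtype, dtype=dtype)
    return torch.triu(mask, diagonal=1)  # additive mask: dtype minimum above the diagonal

print(build_causal_mask(4, causal=False))  # None
print(build_causal_mask(4, causal=True))   # zeros on/below the diagonal, dtype minimum above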
@@ -630,4 +637,59 @@ def forward(
logits=logits,
aux_logits=aux_logits,
hidden_states=outputs.hidden_states,
)

@dataclass
class TextEncoderOutput(CausalLMOutput):
hidden_states: torch.FloatTensor = None


class SuryaOCRTextEncoder(SuryaOCRDecoderPreTrainedModel):
_tied_weights_keys = None
config_class = SuryaOCRTextEncoderConfig

def __init__(self, config, **kwargs):
super().__init__(config)
self.model = SuryaOCRDecoderModel(config)
self.vocab_size = config.vocab_size

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

def set_decoder(self, decoder):
self.model = decoder

def get_decoder(self):
return self.model

# Ignore copy
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
**kwargs
) -> Union[Tuple, CausalLMOutput]:
outputs = self.model(
input_ids=input_ids,
cache_position=cache_position,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_hidden_states=True,
return_dict=True,
)

return TextEncoderOutput(
hidden_states=outputs.last_hidden_state,
)
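
A hedged usage sketch of the new wrapper, assuming the surya package at this commit is importable. It only constructs the module and inspects it; a forward pass additionally needs vision-encoder hidden states for the cross-attention layers.

from surya.model.recognition.config import SuryaOCRTextEncoderConfig
from surya.model.recognition.decoder import SuryaOCRTextEncoder

text_encoder = SuryaOCRTextEncoder(SuryaOCRTextEncoderConfig())

emb = text_encoder.get_input_embeddings()
print(emb.num_embeddings, emb.embedding_dim)             # 65792 1024, from the config defaults
print(text_encoder.get_decoder() is text_encoder.model)  # True: the wrapper reuses SuryaOCRDecoderModel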
56 changes: 21 additions & 35 deletions surya/model/recognition/encoder.py
@@ -321,7 +321,7 @@ def __init__(self, config, dim, num_heads, num_kv_heads, window_size):
self.key = nn.Linear(self.all_head_size, self.kv_head_size, bias=config.qkv_bias)
self.value = nn.Linear(self.all_head_size, self.kv_head_size, bias=config.qkv_bias)

- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.dropout_p = config.attention_probs_dropout_prob

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -344,53 +344,39 @@ def forward(
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)

# Final is (batch_size, num_attention_heads, seq_len, attention_head_size)
key_layer = self.transpose_kv_for_scores(self.key(hidden_states), self.kv_repeats)
value_layer = self.transpose_kv_for_scores(self.value(hidden_states), self.kv_repeats)
query_layer = self.transpose_for_scores(mixed_query_layer)

- # Take the dot product between "query" and "key" to get the raw attention scores.
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

- attention_scores = attention_scores / math.sqrt(self.attention_head_size)

relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
relative_position_bias = relative_position_bias.view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
)

- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
- attention_scores = attention_scores + relative_position_bias.unsqueeze(0)

- if attention_mask is not None:
- # Apply the attention mask is (precomputed for all layers in DonutSwinModel forward() function)
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
+ if attention_mask is None:
+     attention_mask = relative_position_bias
+ else:
mask_shape = attention_mask.shape[0]
- attention_scores = attention_scores.view(
- batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
- )
- attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
- attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)

- # Normalize the attention scores to probabilities.
- attention_probs = nn.functional.softmax(attention_scores, dim=-1)

- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
- attention_probs = self.dropout(attention_probs)

- # Mask heads if we want to
- if head_mask is not None:
- attention_probs = attention_probs * head_mask

- context_layer = torch.matmul(attention_probs, value_layer)
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(new_context_layer_shape)
+ repeat_count = (batch_size // mask_shape)
+ attention_mask = attention_mask.repeat(repeat_count, 1, 1).unsqueeze(1)
+ attention_mask = attention_mask + relative_position_bias

+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_layer.contiguous(),
+ key_layer.contiguous(),
+ value_layer.contiguous(),
+ attn_mask=attention_mask,
+ dropout_p=self.dropout_p if self.training else 0.0,
+ scale=self.attention_head_size**-0.5,
+ )

- outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(batch_size, dim, num_channels)

+ outputs = (attn_output,)
return outputs


# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
class DonutSwinSelfOutput(nn.Module):
def __init__(self, config, dim):
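
A hedged sketch of the equivalence behind the attention rewrite above: the manual matmul/softmax/dropout path is replaced by torch.nn.functional.scaled_dot_product_attention, with the relative position bias (and any window mask) folded into an additive attn_mask. Shapes are illustrative; the scale keyword requires a PyTorch version that supports it, which the committed code also relies on.

import torch
import torch.nn.functional as F

bsz, heads, seq, head_dim = 2, 4, 49, 32
q, k, v = (torch.randn(bsz, heads, seq, head_dim) for _ in range(3))
bias = torch.randn(1, heads, seq, seq)  # relative position bias, broadcast over the batch

# Old path: explicit scores, additive bias, softmax, weighted sum.
scores = q @ k.transpose(-1, -2) * head_dim**-0.5 + bias
manual = scores.softmax(dim=-1) @ v

# New path: fused kernel, bias passed as an additive attn_mask.
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=bias, scale=head_dim**-0.5)

print(torch.allclose(manual, fused, atol=1e-5))  # True, up to numerical tolerance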
8 changes: 7 additions & 1 deletion surya/model/recognition/encoderdecoder.py
@@ -5,7 +5,7 @@
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right
from surya.model.recognition.encoder import DonutSwinModel
from surya.model.recognition.decoder import SuryaOCRDecoder
from surya.model.recognition.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder


class OCREncoderDecoderModel(PreTrainedModel):
@@ -20,6 +20,7 @@ def __init__(
config: Optional[PretrainedConfig] = None,
encoder: Optional[PreTrainedModel] = None,
decoder: Optional[PreTrainedModel] = None,
text_encoder: Optional[PreTrainedModel] = None,
):
# initialize with config
# make sure input & output embeddings is not tied
@@ -33,13 +34,18 @@ def __init__(
if decoder is None:
decoder = SuryaOCRDecoder(config.decoder, attn_implementation=config._attn_implementation)

if text_encoder is None:
text_encoder = SuryaOCRTextEncoder(config.text_encoder, attn_implementation=config._attn_implementation)

self.encoder = encoder
self.decoder = decoder
self.text_encoder = text_encoder

# make sure that the individual model's config refers to the shared config
# so that the updates to the config will be synced
self.encoder.config = self.config.encoder
self.decoder.config = self.config.decoder
self.text_encoder.config = self.config.text_encoder

def get_encoder(self):
return self.encoder