diff --git a/README.md b/README.md index 792c462..3067cdb 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ There is a hosted API for all surya models available [here](https://www.datalab. I want surya to be as widely accessible as possible, while still funding my development/training costs. Research and personal usage is always okay, but there are some restrictions on commercial usage. -The weights for the models are licensed `cc-by-nc-sa-4.0`, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period AND under $5M in lifetime VC/angel funding raised. If you want to remove the GPL license requirements (dual-license) and/or use the weights commercially over the revenue limit, check out the options [here](https://www.datalab.to). +The weights for the models are licensed `cc-by-nc-sa-4.0`, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period AND under $5M in lifetime VC/angel funding raised. You also must not be competitive with the [Datalab API](https://www.datalab.to/). If you want to remove the GPL license requirements (dual-license) and/or use the weights commercially over the revenue limit, check out the options [here](https://www.datalab.to). # Installation @@ -84,12 +84,12 @@ surya_gui This command will write out a json file with the detected text and bboxes: ```shell -surya_ocr DATA_PATH --images --langs hi,en +surya_ocr DATA_PATH ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs -- `--langs` specifies the language(s) to use for OCR. You can comma separate multiple languages (I don't recommend using more than `4`). Use the language name or two-letter ISO code from [here](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes). Surya supports the 90+ languages found in `surya/languages.py`. -- `--lang_file` if you want to use a different language for different PDFs/images, you can specify languages here. The format is a JSON dict with the keys being filenames and the values as a list, like `{"file1.pdf": ["en", "hi"], "file2.pdf": ["en"]}`. +- `--langs` is an optional (but recommended) argument that specifies the language(s) to use for OCR. You can comma separate multiple languages. Use the language name or two-letter ISO code from [here](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes). Surya supports the 90+ languages found in `surya/languages.py`. +- `--lang_file` if you want to use a different language for different PDFs/images, you can optionally specify languages in a file. The format is a JSON dict with the keys being filenames and the values as a list, like `{"file1.pdf": ["en", "hi"], "file2.pdf": ["en"]}`. - `--images` will save images of the pages and detected text lines (optional) - `--results_dir` specifies the directory to save results to instead of the default - `--max` specifies the maximum number of pages to process if you don't want to process everything @@ -108,21 +108,21 @@ The `results.json` file will contain a json dictionary where the keys are the in **Performance tips** -Setting the `RECOGNITION_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `50MB` of VRAM, so very high batch sizes are possible. The default is a batch size `256`, which will use about 12.8GB of VRAM. Depending on your CPU core count, it may help, too - the default CPU batch size is `32`. +Setting the `RECOGNITION_BATCH_SIZE` env var properly will make a big difference when using a GPU. 
Each batch item will use `40MB` of VRAM, so very high batch sizes are possible. The default is a batch size `512`, which will use about 20GB of VRAM. Depending on your CPU core count, it may help, too - the default CPU batch size is `32`. ### From python ```python from PIL import Image from surya.ocr import run_ocr -from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor -from surya.model.recognition.model import load_model as load_recognition_model -from surya.model.recognition.processor import load_processor as load_recognition_processor +from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor +from surya.model.recognition.model import load_model as load_rec_model +from surya.model.recognition.processor import load_processor as load_rec_processor image = Image.open(IMAGE_PATH) -langs = ["en"] # Replace with your languages -det_processor, det_model = load_detection_processor(), load_detection_model() -rec_model, rec_processor = load_recognition_model(), load_recognition_processor() +langs = ["en"] # Replace with your languages - optional but recommended +det_processor, det_model = load_det_processor(), load_det_model() +rec_model, rec_processor = load_rec_model(), load_rec_processor() predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor) ``` @@ -134,7 +134,7 @@ The OCR model can be compiled to get an ~15% speedup in total inference time. T ```python import torch -rec_model.decoder.model.decoder = torch.compile(rec_model.decoder.model.decoder) +rec_model.decoder.model = torch.compile(rec_model.decoder.model) ``` ## Text line detection @@ -142,7 +142,7 @@ rec_model.decoder.model.decoder = torch.compile(rec_model.decoder.model.decoder) This command will write out a json file with the detected bboxes. ```shell -surya_detect DATA_PATH --images +surya_detect DATA_PATH ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs @@ -184,7 +184,7 @@ predictions = batch_text_detection([image], model, processor) This command will write out a json file with the detected layout. ```shell -surya_layout DATA_PATH --images +surya_layout DATA_PATH ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs @@ -231,7 +231,7 @@ layout_predictions = batch_layout_detection([image], model, processor, line_pred This command will write out a json file with the detected reading order and layout. ```shell -surya_order DATA_PATH --images +surya_order DATA_PATH ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs @@ -417,7 +417,9 @@ python benchmark/recognition.py --tesseract - `--debug 2` will render images with detected text - `--results_dir` will let you specify a directory to save results to instead of the default one - `--tesseract` will run the benchmark with tesseract. You have to run `sudo apt-get install tesseract-ocr-all` to install all tesseract data, and set `TESSDATA_PREFIX` to the path to the tesseract data folder. + - Set `RECOGNITION_BATCH_SIZE=864` to use the same batch size as the benchmark. +- Set `RECOGNITION_BENCH_DATASET_NAME=vikp/rec_bench_hist` to use the historical document data for benchmarking. This data comes from the [tapuscorpus](https://github.com/HTR-United/tapuscorpus). 
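+For example, a minimal invocation combining these options might look like the following (the env var values are the ones documented above; `--specify_language` is the flag added to `benchmark/recognition.py` in this change, which passes language codes into the model):
+
+```shell
+RECOGNITION_BENCH_DATASET_NAME=vikp/rec_bench_hist RECOGNITION_BATCH_SIZE=864 python benchmark/recognition.py --specify_language
+```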
**Layout analysis** diff --git a/benchmark/recognition.py b/benchmark/recognition.py index 28f1be8..a71ba1e 100644 --- a/benchmark/recognition.py +++ b/benchmark/recognition.py @@ -30,6 +30,7 @@ def main(): parser.add_argument("--langs", type=str, help="Specify certain languages to benchmark.", default=None) parser.add_argument("--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28) parser.add_argument("--compile", action="store_true", help="Compile the model.", default=False) + parser.add_argument("--specify_language", action="store_true", help="Pass language codes into the model.", default=False) args = parser.parse_args() if args.compile: @@ -46,7 +47,7 @@ def main(): if args.langs: langs = args.langs.split(",") - dataset = dataset.filter(lambda x: x["language"] in langs) + dataset = dataset.filter(lambda x: x["language"] in langs, num_proc=4) images = list(dataset["image"]) images = convert_if_not_rgb(images) @@ -62,14 +63,17 @@ def main(): lang_list.append([l]) else: lang_list.append(l) + n_list = [None] * len(images) if args.compile: - rec_model.decoder.model.decoder = torch.compile(rec_model.decoder.model.decoder) + torch.set_float32_matmul_precision('high') + torch._dynamo.config.cache_size_limit = 64 + rec_model.decoder.model = torch.compile(rec_model.decoder.model) # Run through one batch to compile the model run_recognition(images[:1], lang_list[:1], rec_model, rec_processor, bboxes=bboxes[:1]) start = time.time() - predictions_by_image = run_recognition(images, lang_list, rec_model, rec_processor, bboxes=bboxes) + predictions_by_image = run_recognition(images, lang_list if args.specify_language else n_list, rec_model, rec_processor, bboxes=bboxes) surya_time = time.time() - start surya_scores = defaultdict(list) @@ -84,9 +88,9 @@ def main(): flat_surya_scores = [s for l in surya_scores for s in surya_scores[l]] benchmark_stats = { "surya": { - "avg_score": sum(flat_surya_scores) / len(flat_surya_scores), - "lang_scores": {l: sum(scores) / len(scores) for l, scores in surya_scores.items()}, - "time_per_img": surya_time / len(images) + "avg_score": sum(flat_surya_scores) / max(1, len(flat_surya_scores)), + "lang_scores": {l: sum(scores) / max(1, len(scores)) for l, scores in surya_scores.items()}, + "time_per_img": surya_time / max(1, len(images)) } } @@ -134,7 +138,7 @@ def main(): json.dump(benchmark_stats, f) key_languages = [k for k in KEY_LANGUAGES if k in surya_scores] - table_headers = ["Model", "Time per page (s)", "Avg Score"] + KEY_LANGUAGES + table_headers = ["Model", "Time per page (s)", "Avg Score"] + key_languages table_data = [ ["surya", benchmark_stats["surya"]["time_per_img"], benchmark_stats["surya"]["avg_score"]] + [benchmark_stats["surya"]["lang_scores"][l] for l in key_languages], ] diff --git a/ocr_app.py b/ocr_app.py index e8f20e3..ff05cc6 100644 --- a/ocr_app.py +++ b/ocr_app.py @@ -125,7 +125,7 @@ def page_count(pdf_file): """) in_file = st.sidebar.file_uploader("PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"]) -languages = st.sidebar.multiselect("Languages", sorted(list(CODE_TO_LANGUAGE.values())), default=["English"], max_selections=4) +languages = st.sidebar.multiselect("Languages", sorted(list(CODE_TO_LANGUAGE.values())), default=[], max_selections=4, help="Select the languages in the image (if known) to improve OCR accuracy. 
Optional.") if in_file is None: st.stop() diff --git a/ocr_text.py b/ocr_text.py index e624c14..5b5bd65 100644 --- a/ocr_text.py +++ b/ocr_text.py @@ -1,6 +1,7 @@ import os import argparse import json +import time from collections import defaultdict import torch @@ -23,12 +24,11 @@ def main(): parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None) parser.add_argument("--start_page", type=int, help="Page to start processing at.", default=0) parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False) - parser.add_argument("--langs", type=str, help="Language(s) to use for OCR. Comma separate for multiple. Can be a capitalized language name, or a 2-letter ISO 639 code.", default=None) - parser.add_argument("--lang_file", type=str, help="Path to file with languages to use for OCR. Should be a JSON dict with file names as keys, and the value being a list of language codes/names.", default=None) + parser.add_argument("--langs", type=str, help="Optional language(s) to use for OCR. Comma separate for multiple. Can be a capitalized language name, or a 2-letter ISO 639 code.", default=None) + parser.add_argument("--lang_file", type=str, help="Optional path to file with languages to use for OCR. Should be a JSON dict with file names as keys, and the value being a list of language codes/names.", default=None) + parser.add_argument("--debug", action="store_true", help="Enable debug logging.", default=False) args = parser.parse_args() - assert args.langs or args.lang_file, "Must provide either --langs or --lang_file" - if os.path.isdir(args.input_path): images, names = load_from_folder(args.input_path, args.max, args.start_page) folder_name = os.path.basename(args.input_path) @@ -42,23 +42,29 @@ def main(): for lang in langs: replace_lang_with_code(lang) image_langs = langs - else: + elif args.langs: # We got our language settings from the input langs = args.langs.split(",") replace_lang_with_code(langs) image_langs = [langs] * len(images) + else: + image_langs = [None] * len(images) det_processor = load_detection_processor() det_model = load_detection_model() - _, lang_tokens = _tokenize("", get_unique_langs(image_langs)) - rec_model = load_recognition_model(langs=lang_tokens) # Prune model moe layer to only include languages we need + rec_model = load_recognition_model() rec_processor = load_recognition_processor() result_path = os.path.join(args.results_dir, folder_name) os.makedirs(result_path, exist_ok=True) + start = time.time() predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor) + if args.debug: + print(f"OCR took {time.time() - start:.2f} seconds") + max_chars = max([len(l.text) for p in predictions_by_image for l in p.text_lines]) + print(f"Max chars: {max_chars}") if args.images: for idx, (name, image, pred, langs) in enumerate(zip(names, images, predictions_by_image, image_langs)): diff --git a/pyproject.toml b/pyproject.toml index 26cd6bd..7f26923 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "surya-ocr" -version = "0.4.15" +version = "0.5.0" description = "OCR, layout, reading order, and line detection in 90+ languages" authors = ["Vik Paruchuri "] readme = "README.md" diff --git a/surya/detection.py b/surya/detection.py index 08e9852..a808635 100644 --- a/surya/detection.py +++ b/surya/detection.py @@ -18,7 +18,7 @@ def get_batch_size(): batch_size = settings.DETECTOR_BATCH_SIZE if batch_size is None: - 
batch_size = 6 + batch_size = 8 if settings.TORCH_DEVICE_MODEL == "mps": batch_size = 8 if settings.TORCH_DEVICE_MODEL == "cuda": diff --git a/surya/input/processing.py b/surya/input/processing.py index 9933279..f8eaf21 100644 --- a/surya/input/processing.py +++ b/surya/input/processing.py @@ -84,6 +84,8 @@ def slice_bboxes_from_image(image: Image.Image, bboxes): lines = [] for bbox in bboxes: line = image.crop((bbox[0], bbox[1], bbox[2], bbox[3])) + if line.size[0] == 0: + print(f"Warning: found an empty line with bbox {bbox}") lines.append(line) return lines diff --git a/surya/languages.py b/surya/languages.py index 83667cf..d4bfbd4 100644 --- a/surya/languages.py +++ b/surya/languages.py @@ -1,4 +1,5 @@ CODE_TO_LANGUAGE = { + "_math": "Math", 'af': 'Afrikaans', 'am': 'Amharic', 'ar': 'Arabic', diff --git a/surya/model/recognition/config.py b/surya/model/recognition/config.py index 23d9bbf..9ed750b 100644 --- a/surya/model/recognition/config.py +++ b/surya/model/recognition/config.py @@ -1,15 +1,251 @@ -from transformers import T5Config, MBartConfig, DonutSwinConfig +from dataclasses import dataclass +import torch +from transformers import PretrainedConfig +from transformers.utils import ModelOutput -class MBartMoEConfig(MBartConfig): - pass +class SuryaOCRConfig(PretrainedConfig): + model_type = "vision-encoder-decoder" + is_composition = True -class VariableDonutSwinConfig(DonutSwinConfig): - pass + def __init__(self, **kwargs): + super().__init__(**kwargs) + encoder_config = kwargs.pop("encoder") + decoder_config = kwargs.pop("decoder") + + self.encoder = encoder_config + self.decoder = decoder_config + self.is_encoder_decoder = True + + if isinstance(decoder_config, dict): + self.decoder_start_token_id = decoder_config["bos_token_id"] + self.pad_token_id = decoder_config["pad_token_id"] + self.eos_token_id = decoder_config["eos_token_id"] + else: + self.decoder_start_token_id = decoder_config.bos_token_id + self.pad_token_id = decoder_config.pad_token_id + self.eos_token_id = decoder_config.eos_token_id + + +class DonutSwinConfig(PretrainedConfig): + model_type = "donut-swin" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + image_size=(256, 896), + patch_size=4, + num_channels=3, + embed_dim=128, + depths=[2, 2, 14, 2], + num_heads=[4, 8, 16, 32], + num_kv_heads=[1, 2, 4, 8], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=True, + initializer_range=0.02, + layer_norm_eps=1e-5, + encoder_length=256, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + 
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) + self.encoder_length = encoder_length + + +class SuryaOCRDecoderConfig(PretrainedConfig): + model_type = "surya_ocr" + + def __init__( + self, + num_hidden_layers=10, + vocab_size=65792, + hidden_size=1024, + intermediate_size=4 * 1024, + num_attention_heads=16, + lru_width=None, + attention_window_size=16, + conv1d_width=4, + logits_soft_cap=30.0, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=1, + hidden_activation="gelu_pytorch_tanh", + rope_theta=10000.0, + block_types=("attention",), + cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + self_attn_layers=(0, 1, 3, 5, 7, 9), + global_attn_layers=(0, 1, 3, 5, 7, 9), + attention_dropout=0.0, + num_key_value_heads=2, + attention_bias=False, + w_init_variance_scale=0.01, + init_std=0.02, + tie_word_embeddings=False, + aux_heads=0, # How many n-token-ahead heads to add + encoder_hidden_size=1024, + causal=False, + **kwargs, + ): + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.lru_width = lru_width if lru_width is not None else hidden_size + self.attention_window_size = attention_window_size + self.conv1d_width = conv1d_width + self.logits_soft_cap = logits_soft_cap + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.block_types = list(block_types) + self.hidden_activation = hidden_activation + self.head_dim = self.hidden_size // self.num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + if self.num_key_value_heads > self.num_attention_heads: + raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`") + self.cross_attn_layers = cross_attn_layers + self.self_attn_layers = self_attn_layers + self.global_attn_layers = global_attn_layers + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.w_init_variance_scale = w_init_variance_scale + self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers + self.init_std = init_std + self.tie_word_embeddings = tie_word_embeddings + self.aux_heads = aux_heads + self.encoder_hidden_size = encoder_hidden_size + self.causal = causal + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + @property + def layers_block_type(self): + return (self.block_types * 100)[: self.num_hidden_layers] + + +class SuryaOCRTextEncoderConfig(PretrainedConfig): + model_type = "surya_ocr" + + def __init__( + self, + num_hidden_layers=10, + vocab_size=65792, + hidden_size=1024, + intermediate_size=4 * 1024, + num_attention_heads=16, + lru_width=None, + attention_window_size=16, + conv1d_width=4, + logits_soft_cap=30.0, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=1, + hidden_activation="gelu_pytorch_tanh", + rope_theta=10000.0, + block_types=("attention",), + cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + self_attn_layers=(0, 1, 3, 5, 7, 9), + global_attn_layers=(0, 1, 3, 5, 7, 9), + attention_dropout=0.0, + num_key_value_heads=2, + attention_bias=False, + w_init_variance_scale=0.01, + init_std=0.02, + tie_word_embeddings=False, + aux_heads=0, # How many n-token-ahead heads to add + encoder_hidden_size=1024, + iteration_count=1, + causal=False, 
+ query_token_count=128, + **kwargs, + ): + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.lru_width = lru_width if lru_width is not None else hidden_size + self.attention_window_size = attention_window_size + self.conv1d_width = conv1d_width + self.logits_soft_cap = logits_soft_cap + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.block_types = list(block_types) + self.hidden_activation = hidden_activation + self.head_dim = self.hidden_size // self.num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + if self.num_key_value_heads > self.num_attention_heads: + raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`") + self.cross_attn_layers = cross_attn_layers + self.self_attn_layers = self_attn_layers + self.global_attn_layers = global_attn_layers + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.w_init_variance_scale = w_init_variance_scale + self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers + self.init_std = init_std + self.tie_word_embeddings = tie_word_embeddings + self.aux_heads = aux_heads + self.encoder_hidden_size = encoder_hidden_size + self.iteration_count = iteration_count + self.causal = causal + self.query_token_count = query_token_count + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + @property + def layers_block_type(self): + return (self.block_types * 100)[: self.num_hidden_layers] -# Config specific to the model, needed for the tokenizer TOTAL_TOKENS = 65536 TOKEN_OFFSET = 3 # Pad, eos, bos SPECIAL_TOKENS = 253 @@ -107,5 +343,6 @@ class VariableDonutSwinConfig(DonutSwinConfig): 'vi': 89, 'xh': 90, 'yi': 91, - 'zh': 92 + 'zh': 92, + "_math": 93 } \ No newline at end of file diff --git a/surya/model/recognition/decoder.py b/surya/model/recognition/decoder.py index fe2d4f4..071f701 100644 --- a/surya/model/recognition/decoder.py +++ b/surya/model/recognition/decoder.py @@ -1,511 +1,695 @@ -import copy -from typing import Optional, List, Union, Tuple +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union -from transformers import MBartForCausalLM, MBartConfig +import torch +import torch.utils.checkpoint from torch import nn +from transformers.utils import ModelOutput + +from surya.model.recognition.config import SuryaOCRDecoderConfig, SuryaOCRTextEncoderConfig +from transformers import PreTrainedModel from transformers.activations import ACT2FN -from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions -from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder -from surya.model.recognition.config import MBartMoEConfig -import torch -import math +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from surya.settings import settings -class MBartLearnedPositionalEmbedding(nn.Embedding): - """ - This module learns positional embeddings up to a fixed maximum size. 
- """ +_MAX_SQRT_GRADIENT = 1000.0 - def __init__(self, num_embeddings: int, embedding_dim: int): - # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack - self.offset = 2 - super().__init__(num_embeddings + self.offset, embedding_dim) - def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): - """`input_ids' shape is expected to be [bsz x seqlen].""" +@dataclass +class OCRModelOutput(ModelOutput): + logits: torch.Tensor + aux_logits: torch.Tensor | None = None + hidden_states: torch.Tensor | None = None - bsz, seq_len = input_ids.shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ).expand(bsz, -1) - return super().forward(positions + self.offset) +class SuryaOCRDecoderRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) -class MBartExpertMLP(nn.Module): - def __init__(self, config: MBartConfig, is_lg=False, is_xl=False): - super().__init__() - self.ffn_dim = config.d_expert - if is_lg: - self.ffn_dim = config.d_expert_lg - if is_xl: - self.ffn_dim = config.d_expert_xl - self.hidden_dim = config.d_model + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst SuryaOCRDecoder is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) - self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.dropout = nn.Dropout(config.activation_dropout) + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" - self.act_fn = ACT2FN[config.activation_function] - def forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states +ALL_LAYERNORM_LAYERS.append(SuryaOCRDecoderRMSNorm) -class MBartExpertLayer(nn.Module): - # From mixtral, with modifications - def __init__(self, config): +class SuryaOCRDecoderRotaryEmbedding(nn.Module): + def __init__(self, dim, base=10000, device=None): super().__init__() - self.dropout = nn.Dropout(config.activation_dropout) - - self.hidden_dim = config.d_model + self.dim = dim + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaOCRDecoder + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from 
transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed - self.lg_lang_codes = sorted(config.lg_langs.values()) if hasattr(config, "lg_langs") else [] - self.xl_lang_codes = sorted(config.xl_langs.values()) if hasattr(config, "xl_langs") else [] - self.lang_codes = sorted(config.langs.values()) - self.num_experts = len(self.lang_codes) +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - self.experts = nn.ModuleDict({str(lang): MBartExpertMLP(config, is_lg=(lang in self.lg_lang_codes), is_xl=(lang in self.xl_lang_codes)) for lang in self.lang_codes}) - def forward(self, hidden_states: torch.Tensor, langs: torch.LongTensor) -> torch.Tensor: - batch_size, sequence_length, hidden_dim = hidden_states.shape +class SuryaOCRDecoderSdpaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper + Modified for GQA + """ - final_hidden_states = torch.zeros( - (batch_size, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + def __init__(self, config: SuryaOCRDecoderConfig): + super().__init__() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + + self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True) + self.rotary_emb = SuryaOCRDecoderRotaryEmbedding( + self.head_dim, + base=config.rope_theta, ) - # Weight experts based on how many languages in the input - routing_weights = 1 / ((langs > 3).sum(axis=-1)) - # Set weights to 1 if zero experts activated - routing_weights[torch.isinf(routing_weights)] = 1 - - unique_langs = langs.unique(dim=None, sorted=True) - unique_langs = unique_langs[unique_langs > 3] # Remove start token + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # Encoder attention mask currently ignored - # Loop over all available experts in the model and perform the computation on each expert - for expert_lang in unique_langs: - # Check which samples match with this expert - lang_match = (langs == expert_lang).any(dim=-1) - idx = torch.nonzero(lang_match, as_tuple=True)[0] + bsz, q_len, _ = hidden_states.size() + _, v_len, _ = encoder_hidden_states.size() - if idx.shape[0] == 0: - continue + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2) - expert_layer = self.experts[str(expert_lang.item())] + if self.key_states is None: + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + key_states = key_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = 
value_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if use_cache: + self._update_cache(key_states, value_states) + else: + key_states = self.key_states + value_states = self.value_states + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states.contiguous(), + key_states.contiguous(), + value_states.contiguous(), + attn_mask=None, + dropout_p=self.attention_dropout if self.training else 0.0, + scale=self.head_dim**-0.5, + ) - current_state = hidden_states[idx] - current_hidden_states = expert_layer(current_state.view(-1, hidden_dim)) - current_hidden_states = current_hidden_states.view(-1, sequence_length, hidden_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output - # Weight by number of languages in the input - selected_routing_weights = routing_weights[idx].view(-1, 1, 1) - current_hidden_states *= selected_routing_weights + def _setup_cache(self, batch_size, device, dtype=None): + # Setup initial caches + self.value_states = None + self.key_states = None - final_hidden_states.index_add_(0, idx, current_hidden_states) + @torch.no_grad() + def _update_cache(self, key_states, value_states, **cache_kwargs): + self.value_states = value_states + self.key_states = key_states - return final_hidden_states +class SuryaOCRDecoderSdpaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" -class MBartGQAttention(nn.Module): - def __init__( - self, - embed_dim: int, - num_heads: int, - num_kv_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - is_causal: bool = False, - config: Optional[MBartConfig] = None, - ): + def __init__(self, config: SuryaOCRDecoderConfig): super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.num_kv_groups = self.num_heads // self.num_kv_heads - - self.dropout = dropout - self.head_dim = embed_dim // num_heads self.config = config - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - self.is_causal = is_causal - - self.k_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def _shape_key_value(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous() + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + + self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = 
nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True) + self.rotary_emb = SuryaOCRDecoderRotaryEmbedding( + self.head_dim, + base=config.rope_theta, + ) def forward( self, hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - is_prefill: Optional[bool] = False, + position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: bool = False, + window_attn: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if is_cross_attention: - if is_prefill: - # cross_attentions - key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz) - past_key_value = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0) - else: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - past_key_value = None - # Self-attention - else: - if is_prefill: - # initial prompt - key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz) - past_key_value = torch.cat([key_states[:, :, -tgt_len:].unsqueeze(0), value_states[:, :, -tgt_len:].unsqueeze(0)], dim=0) - else: - # reuse k, v, self_attention - key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - past_key_value = torch.cat([key_states[:, :, -tgt_len:].unsqueeze(0), value_states[:, :, -tgt_len:].unsqueeze(0)], dim=0) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Final is bsz, num_attention_heads, seq_len, head_dim + query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if use_cache and hasattr(self, "key_states"): + cache_kwargs = {"cache_position": cache_position, "window_attn": window_attn} + key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = 
repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + # Mask is batch, head, seq_len, kv_len + causal_mask = causal_mask[:, :, :, :key_states.shape[-2]] + current_cache_position = cache_position[-1].item() if cache_position is not None else None + if current_cache_position and settings.RECOGNITION_STATIC_CACHE: + # Mask out future cache positions + position_mask = torch.ones_like(causal_mask, dtype=torch.bool, device=causal_mask.device) + position_mask[:, :, :, :current_cache_position + 1] = False + causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states.contiguous(), + key_states.contiguous(), + value_states.contiguous(), + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + scale=self.head_dim**-0.5, + ) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output - # Expand kv heads, then match query shape - key_states = key_states.repeat_interleave(self.num_kv_groups, dim=1).reshape(*proj_shape) - value_states = value_states.repeat_interleave(self.num_kv_groups, dim=1).reshape(*proj_shape) + def _setup_cache(self, batch_size, device, dtype=None): + if dtype is None and self.config.torch_dtype is not None: + dtype = self.config.torch_dtype + dtype = dtype if dtype is not None else torch.float32 - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + # Setup initial caches + self.value_states = None + self.key_states = None - if not is_cross_attention: - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + if settings.RECOGNITION_STATIC_CACHE: + cache_shape = (batch_size, self.num_key_value_heads, settings.RECOGNITION_MAX_TOKENS, self.head_dim) + self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device) + self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device) - attn_weights = nn.functional.softmax(attn_weights, dim=-1) + def _update_static_cache(self, key_states, value_states, **cache_kwargs): + cache_position = cache_kwargs.get("cache_position") + k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device) - attn_output = torch.bmm(attn_weights, value_states).view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1,2) + k_out[:, :, cache_position] = key_states.to(k_out.dtype) + v_out[:, :, cache_position] = value_states.to(v_out.dtype) - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. 
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - attn_output = self.out_proj(attn_output) + self.key_states, self.value_states = k_out, v_out + return k_out, v_out - return attn_output, past_key_value + def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs): + k_out = key_states + if self.key_states is not None: + k_out = torch.cat([self.key_states, key_states], dim=2) + v_out = value_states + if self.value_states is not None: + v_out = torch.cat([self.value_states, value_states], dim=2) -class MBartMoEDecoderLayer(nn.Module): - def __init__(self, config: MBartConfig, has_moe=False): - super().__init__() - self.embed_dim = config.d_model - - self.self_attn = MBartGQAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - num_kv_heads=config.kv_heads, - dropout=config.attention_dropout, - is_decoder=True, - is_causal=True, - config=config, - ) - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MBartGQAttention( - self.embed_dim, - config.decoder_attention_heads, - num_kv_heads=config.kv_heads, - dropout=config.attention_dropout, - is_decoder=True, - config=config, - ) - self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.has_moe = has_moe - if has_moe: - self.moe = MBartExpertLayer(config) - else: - self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) - self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self.key_states, self.value_states = k_out, v_out + return k_out, v_out - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - langs: Optional[torch.LongTensor] = None, - self_kv_cache: Optional[torch.Tensor] = None, - cross_kv_cache: Optional[torch.Tensor] = None, - is_prefill: Optional[bool] = False, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = True, - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_kv_cache, - is_prefill=is_prefill, - attention_mask=attention_mask, - ) - hidden_states = residual + hidden_states - - # Cross-Attention Block - if encoder_hidden_states is not None: - residual = hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - hidden_states, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - is_prefill=is_prefill, - attention_mask=encoder_attention_mask, - past_key_value=cross_kv_cache, - ) - hidden_states = residual + hidden_states + @torch.no_grad() + def _update_cache(self, key_states, value_states, **cache_kwargs): + if settings.RECOGNITION_STATIC_CACHE: + return self._update_static_cache(key_states, value_states, **cache_kwargs) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = (present_key_value, cross_attn_present_key_value) + return 
self._update_dynamic_cache(key_states, value_states, **cache_kwargs) - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - if self.has_moe: - hidden_states = self.moe(hidden_states, langs) - else: - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.fc2(hidden_states) - hidden_states = residual + hidden_states +class SuryaOCRDecoderMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + if config.hidden_activation is None: + config.hidden_activation = "gelu_pytorch_tanh" + hidden_activation = config.hidden_activation + self.act_fn = ACT2FN[hidden_activation] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class SuryaOCRDecoderLayer(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + super().__init__() + self.cross_pre_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.temporal_pre_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - outputs = (hidden_states,) + self.temporal_block = None + if layer_idx in config.self_attn_layers: + self.temporal_block = SuryaOCRDecoderSdpaAttention(config) - if use_cache: - outputs += (present_key_value,) + self.cross_attn_block = None + if layer_idx in config.cross_attn_layers: + self.cross_attn_block = SuryaOCRDecoderSdpaCrossAttention(config) - return outputs + self.window_attn = layer_idx not in config.global_attn_layers + self.channel_pre_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp_block = SuryaOCRDecoderMlp(config) + def forward( + self, + activations: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_attention_mask: torch.Tensor = None, + cache_position: torch.Tensor = None, + use_cache: bool = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raw_activations = activations + + if self.cross_attn_block is not None: + # Do cross-attention on encoder outputs + cross_attn_inputs = self.cross_pre_norm(activations) + cross_attn_path = self.cross_attn_block( + cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache + ) + cross_attn_output = cross_attn_path + raw_activations + else: + cross_attn_output = raw_activations -class MBartMoEDecoder(MBartDecoder): - def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None): - MBartPreTrainedModel.__init__(self, config) - self.dropout = config.dropout - self.layerdrop = config.decoder_layerdrop - self.padding_idx = config.pad_token_id - self.max_target_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + if self.temporal_block is not None: + inputs_normalized = self.temporal_pre_norm(cross_attn_output) # RMSNorm introduces slight slight differences + hidden_states = self.temporal_block( + inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn + ) - self.embed_tokens = 
nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + residual = hidden_states + raw_activations + else: + residual = cross_attn_output + + hidden_states = self.channel_pre_norm(residual) + hidden_states = self.mlp_block(hidden_states) + + hidden_states = hidden_states + residual + return hidden_states + + +class SuryaOCRDecoderPreTrainedModel(PreTrainedModel): + config_class = SuryaOCRDecoderConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["SuryaOCRDecoderLayer"] + _skip_keys_device_placement = ["cache"] + _supports_flash_attn_2 = False + _supports_sdpa = False # we can't compare with eager for now + _supports_cache_class = True + _supports_quantized_cache = True + + def _init_weights(self, module): + if isinstance(module, SuryaOCRDecoderSdpaAttention): + torch.nn.init.normal_(module.q_proj.weight, mean=0.0, std=self.config.init_std) + torch.nn.init.normal_(module.k_proj.weight, mean=0.0, std=self.config.init_std) + torch.nn.init.normal_(module.v_proj.weight, mean=0.0, std=self.config.init_std) + + torch.nn.init.normal_(module.o_proj.weight, mean=0.0, std=self.config.init_std) + elif isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) + if getattr(module, "bias", None) is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.init_std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _setup_cache(self, config, batch, device, dtype): + layers = getattr(self, "model", self).layers + for layer in layers: + if layer.temporal_block: + layer.temporal_block._setup_cache(batch, device, dtype) + if layer.cross_attn_block: + layer.cross_attn_block._setup_cache(batch, device, dtype) + + def reset_cache(self, batch, device, dtype): + pass + + def _tie_weights(self): + pass + + def tie_weights(self): + pass + + +class SuryaOCRDecoderModel(SuryaOCRDecoderPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`SuryaOCRDecoderDecoderLayer`] - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight + Args: + config: SuryaOCRDecoderConfig + """ - self.embed_positions = MBartLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - ) - # Language-specific MoE goes at second and second-to-last layer - self.layers = nn.ModuleList([MBartMoEDecoderLayer(config, has_moe=(i in config.moe_layers) and config.use_moe) for i in range(config.decoder_layers)]) - self.layernorm_embedding = nn.LayerNorm(config.d_model) - self.layer_norm = nn.LayerNorm(config.d_model) + def __init__(self, config: SuryaOCRDecoderConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.causal = config.causal + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [SuryaOCRDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.final_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False + + self.register_buffer( + "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), persistent=False + ) # Initialize weights and apply final processing self.post_init() + # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings + def get_input_embeddings(self): + return self.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings + def set_input_embeddings(self, value): + self.embed_tokens = value + def forward( self, input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, - self_kv_cache: Optional[torch.Tensor] = None, - cross_kv_cache: Optional[torch.Tensor] = None, - past_token_count: Optional[int] = None, - langs: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - use_cache = True - return_dict = True - - input = input_ids - input_shape = input.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - # past_key_values_length - past_key_values_length = past_token_count if self_kv_cache is not None else 0 - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - # embed positions - positions = self.embed_positions(input, past_key_values_length) - - hidden_states = inputs_embeds + positions - hidden_states = self.layernorm_embedding(hidden_states) - - # decoder layers - all_hidden_states = None - all_self_attns = None - all_cross_attentions = None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - is_prefill = past_token_count == 0 - layer_self_kv_cache = self_kv_cache[idx] if self_kv_cache is not None else None - layer_cross_kv_cache = cross_kv_cache[idx] if cross_kv_cache is not None else None - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - langs=langs, - self_kv_cache=layer_self_kv_cache, - cross_kv_cache=layer_cross_kv_cache, - is_prefill=is_prefill, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=None, - use_cache=use_cache, - ) - hidden_states = layer_outputs[0] + encoder_attention_mask: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: 
Optional[bool] = None, + return_dict: Optional[bool] = None, + prefill: bool = False + ) -> Union[Tuple, BaseModelOutputWithNoAttention]: + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if use_cache: - next_decoder_cache += (layer_outputs[1],) + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + + if use_cache and prefill: + self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) - hidden_states = self.layer_norm(hidden_states) + all_hidden_states = () if output_hidden_states else None + for i, residual_block in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + residual_block.__call__, hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache + ) + else: + hidden_states = residual_block(hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache) + + hidden_states = self.final_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( last_hidden_state=hidden_states, - past_key_values=next_cache, hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, ) + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114
+    # Ignore copy
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+        if not self.causal:
+            return None
+
+        dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
+        sequence_length = input_tensor.shape[1]
+        target_length = max(settings.RECOGNITION_MAX_TOKENS, sequence_length)
+
+        diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+        causal_mask = diagonal
+        if sequence_length != 1:
+            # Select the upper triangular part of the matrix, but unmask current token (the diagonal)
+            # triu will be the min_dtype, everything else is 0 (attended to)
+            causal_mask = torch.triu(diagonal, diagonal=1)
+
+        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            if attention_mask.dim() == 2:
+                # Mask positions in the causal mask that are masked in the attention mask
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
+                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+
+        if attention_mask is not None and attention_mask.device.type == "cuda":
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+
+class SuryaOCRDecoder(SuryaOCRDecoderPreTrainedModel):
+    _tied_weights_keys = None

-class MBartMoEDecoderWrapper(MBartPreTrainedModel):
-    """
-    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
-    used in combination with the [`EncoderDecoderModel`] framework.
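A minimal, self-contained sketch of the masking scheme `_update_causal_mask` implements above: the mask is built at a fixed width (the static KV-cache length), future positions stay at the dtype minimum, and `cache_position` decides which cache slots are visible. The sizes below are toy values, not surya's real settings, and only the prefill case is shown.

```python
import torch

MAX_TOKENS = 8          # stands in for settings.RECOGNITION_MAX_TOKENS (illustrative only)
seq_len = 3             # tokens fed to the decoder in this step (a prefill of 3 tokens)
cache_position = torch.arange(seq_len)   # cache slots these tokens will occupy

min_val = torch.finfo(torch.float32).min

# Start fully masked, then unmask on/below the diagonal (causal attention)
mask = torch.full((seq_len, MAX_TOKENS), min_val)
mask = torch.triu(mask, diagonal=1)  # 0 on/below the diagonal, min_val above it

# Additionally block cache slots beyond the current position, so a token can
# never attend to slots that have not been written yet.
mask = mask * (torch.arange(MAX_TOKENS) > cache_position.reshape(-1, 1))

print((mask == 0).int())
# Row i attends (1) only to cache slots 0..i:
# [[1, 0, 0, 0, 0, 0, 0, 0],
#  [1, 1, 0, 0, 0, 0, 0, 0],
#  [1, 1, 1, 0, 0, 0, 0, 0]]
```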
- """ - - def __init__(self, config): + def __init__(self, config, **kwargs): super().__init__(config) - self.decoder = MBartMoEDecoder(config) + self.model = SuryaOCRDecoderModel(config) + self.vocab_size = config.vocab_size + aux_heads = config.aux_heads if config.aux_heads is not None else 0 + lm_heads = aux_heads + 1 + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size * lm_heads, bias=False) + + # Initialize weights and apply final processing + self.post_init() - def forward(self, *args, **kwargs): - return self.decoder(*args, **kwargs) + def get_input_embeddings(self): + return self.model.embed_tokens + def set_input_embeddings(self, value): + self.model.embed_tokens = value -class MBartMoE(MBartForCausalLM): - config_class = MBartMoEConfig - _tied_weights_keys = ["lm_head.weight"] + def get_output_embeddings(self): + return self.lm_head - def __init__(self, config, **kwargs): - config = copy.deepcopy(config) - config.is_decoder = True - config.is_encoder_decoder = False - MBartPreTrainedModel.__init__(self, config) - self.model = MBartMoEDecoderWrapper(config) + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + def set_decoder(self, decoder): + self.model = decoder - # Initialize weights and apply final processing - self.post_init() + def get_decoder(self): + return self.model + # Ignore copy def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, - self_kv_cache: Optional[torch.FloatTensor] = None, - cross_kv_cache: Optional[torch.FloatTensor] = None, - past_token_count: Optional[int] = None, - langs: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + prefill: bool = False, **kwargs - ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model.decoder( + ) -> Union[Tuple, OCRModelOutput]: + outputs = self.model( input_ids=input_ids, + cache_position=cache_position, attention_mask=attention_mask, - self_kv_cache=self_kv_cache, - cross_kv_cache=cross_kv_cache, - past_token_count=past_token_count, - langs=langs, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_hidden_states=True, + return_dict=True, + prefill=prefill, ) - logits = self.lm_head(outputs[0]) - - if not return_dict: - output = (logits,) + outputs[1:] - return output + hidden_states = outputs[0] + all_logits = self.lm_head(hidden_states) + all_logits = torch.split(all_logits, self.vocab_size, dim=-1) + logits = all_logits[0] + aux_logits = all_logits[1:] if len(all_logits) > 1 else None - return CausalLMOutputWithCrossAttentions( - loss=None, + return OCRModelOutput( 
logits=logits, - past_key_values=outputs.past_key_values, + aux_logits=aux_logits, hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, ) - def prune_moe_experts(self, keep_keys: List[int]): - # Remove experts not specified in keep_keys - str_keep_keys = [str(key) for key in keep_keys] - for layer in self.model.decoder.layers: - if not layer.has_moe: - continue - - lang_keys = list(layer.moe.experts.keys()) - for lang in lang_keys: - if lang not in str_keep_keys: - layer.moe.experts.pop(lang) - layer.lang_codes = keep_keys +@dataclass +class TextEncoderOutput(CausalLMOutput): + hidden_states: torch.FloatTensor = None + + +class SuryaOCRTextEncoder(SuryaOCRDecoderPreTrainedModel): + _tied_weights_keys = None + config_class = SuryaOCRTextEncoderConfig + + def __init__(self, config, **kwargs): + super().__init__(config) + self.model = SuryaOCRDecoderModel(config) + self.vocab_size = config.vocab_size + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + # Ignore copy + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, CausalLMOutput]: + outputs = self.model( + input_ids=input_ids, + cache_position=cache_position, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_hidden_states=True, + return_dict=True, + ) + + return TextEncoderOutput( + hidden_states=outputs.last_hidden_state, + ) \ No newline at end of file diff --git a/surya/model/recognition/encoder.py b/surya/model/recognition/encoder.py index f01f35c..85fb01c 100644 --- a/surya/model/recognition/encoder.py +++ b/surya/model/recognition/encoder.py @@ -1,42 +1,133 @@ -from torch import nn -import torch +""" EfficientViT (by MIT Song Han's Lab) + +Paper: `Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition` + - https://arxiv.org/abs/2205.14756 + +Code adapted from timm, https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py +Original code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit +""" + +import collections.abc +import math +from dataclasses import dataclass from typing import Optional, Tuple, Union -from transformers.models.donut.modeling_donut_swin import DonutSwinPatchEmbeddings, DonutSwinEmbeddings, DonutSwinModel, \ - DonutSwinEncoder, DonutSwinModelOutput, DonutSwinEncoderOutput, DonutSwinAttention, DonutSwinDropPath, \ - DonutSwinIntermediate, DonutSwinOutput, window_partition, window_reverse +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from transformers.utils import ModelOutput +from surya.model.recognition.config import DonutSwinConfig + 
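The decoder above fuses its output heads into a single `lm_head` sized `hidden_size -> vocab_size * (1 + aux_heads)` and then splits the projection back into a main logit tensor plus optional auxiliary ones. A small sketch of that pattern, with made-up sizes rather than surya's real config:

```python
import torch
from torch import nn

hidden_size, vocab_size, aux_heads = 16, 10, 2   # toy sizes for illustration
lm_heads = aux_heads + 1

# One fused projection instead of (1 + aux_heads) separate nn.Linear modules
lm_head = nn.Linear(hidden_size, vocab_size * lm_heads, bias=False)

hidden_states = torch.randn(2, 5, hidden_size)   # (batch, seq_len, hidden)
all_logits = lm_head(hidden_states)              # (batch, seq_len, vocab_size * lm_heads)

# Split the last dimension back into per-head vocab distributions
all_logits = torch.split(all_logits, vocab_size, dim=-1)
logits = all_logits[0]                           # main head, used for decoding
aux_logits = all_logits[1:] if len(all_logits) > 1 else None

print(logits.shape)      # torch.Size([2, 5, 10])
print(len(aux_logits))   # 2
```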
+_EXPECTED_OUTPUT_SHAPE = [1, 49, 1024]
+
+
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
+class DonutSwinEncoderOutput(ModelOutput):
+
+    last_hidden_state: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class DonutSwinModelOutput(ModelOutput):
+    last_hidden_state: torch.FloatTensor = None
+
-from surya.model.recognition.config import VariableDonutSwinConfig
+# Copied from transformers.models.swin.modeling_swin.window_partition
+def window_partition(input_feature, window_size):
+    """
+    Partitions the given input into windows.
+    """
+    batch_size, height, width, num_channels = input_feature.shape
+    input_feature = input_feature.view(
+        batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
+    )
+    windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
+    return windows
+
+
+# Copied from transformers.models.swin.modeling_swin.window_reverse
+def window_reverse(windows, window_size, height, width):
+    """
+    Merges windows to produce higher resolution features.
+    """
+    num_channels = windows.shape[-1]
+    windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
+    windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
+    return windows

-class VariableDonutSwinEmbeddings(DonutSwinEmbeddings):
+# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin
+class DonutSwinEmbeddings(nn.Module):
     """
     Construct the patch and position embeddings. Optionally, also the mask token.
     """

     def __init__(self, config, use_mask_token=False):
-        super().__init__(config, use_mask_token)
+        super().__init__()

         self.patch_embeddings = DonutSwinPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
         self.patch_grid = self.patch_embeddings.grid_size
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
-        self.position_embeddings = None

         if config.use_absolute_embeddings:
             self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
+        else:
+            self.position_embeddings = None

         self.norm = nn.LayerNorm(config.embed_dim)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
+        resolution images.
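A quick shape check for the `window_partition` / `window_reverse` pair above: partitioning a feature map into non-overlapping windows and merging them back is lossless. The functions are repeated verbatim so the snippet runs standalone, and the tensor sizes are arbitrary.

```python
import torch

def window_partition(input_feature, window_size):
    # (batch, H, W, C) -> (num_windows * batch, window_size, window_size, C)
    batch_size, height, width, num_channels = input_feature.shape
    input_feature = input_feature.view(
        batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
    )
    return input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)

def window_reverse(windows, window_size, height, width):
    # exact inverse of window_partition
    num_channels = windows.shape[-1]
    windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
    return windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)

x = torch.randn(2, 8, 8, 32)                # (batch, H, W, C)
windows = window_partition(x, 4)
print(windows.shape)                        # torch.Size([8, 4, 4, 32]) -> 2 images * 4 windows each
restored = window_reverse(windows, 4, 8, 8)
print(torch.equal(restored, x))             # True
```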
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + def forward( - self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None + self, + pixel_values: Optional[torch.FloatTensor], + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, ) -> Tuple[torch.Tensor]: - + _, num_channels, height, width = pixel_values.shape embeddings, output_dimensions = self.patch_embeddings(pixel_values) - # Layernorm across the last dimension (each patch is a single row) embeddings = self.norm(embeddings) - batch_size, seq_len, embed_dim = embeddings.size() + batch_size, seq_len, _ = embeddings.size() if bool_masked_pos is not None: mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) @@ -45,14 +136,62 @@ def forward( embeddings = embeddings * (1.0 - mask) + mask_tokens * mask if self.position_embeddings is not None: - embeddings = embeddings + self.position_embeddings[:, :seq_len, :] + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings[:, :seq_len] embeddings = self.dropout(embeddings) return embeddings, output_dimensions -class VariableDonutSwinPatchMerging(nn.Module): +# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin +class DonutSwinPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
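The `interpolate_pos_encoding` helper above reshapes the learned position embeddings into a 2D grid and resizes them bicubically for larger inputs. A stripped-down sketch of the same idea, using toy sizes, passing `size=` instead of the original's scale-factor trick, and skipping the class-token handling:

```python
import torch
from torch import nn

dim = 8
old_grid = 7                 # pretrained on a 7x7 patch grid (toy value)
new_grid = (9, 11)           # target patch grid for a larger input image

pos_embed = torch.randn(1, old_grid * old_grid, dim)   # (1, num_positions, dim)

# Reshape to a 2D grid, interpolate, then flatten back to a sequence
grid = pos_embed.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)   # (1, dim, 7, 7)
grid = nn.functional.interpolate(grid, size=new_grid, mode="bicubic", align_corners=False)
resized = grid.permute(0, 2, 3, 1).reshape(1, -1, dim)                     # (1, 9*11, dim)

print(resized.shape)   # torch.Size([1, 99, 8])
```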
+ """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def maybe_pad(self, pixel_values, height, width): + if width % self.patch_size[1] != 0: + pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + return pixel_values + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: + _, num_channels, height, width = pixel_values.shape + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + embeddings = self.projection(pixel_values) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging +class DonutSwinPatchMerging(nn.Module): """ Patch Merging Layer. @@ -106,15 +245,230 @@ def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int] return input_feature -class VariableDonutSwinLayer(nn.Module): - def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
+ """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.swin.modeling_swin.SwinDropPath +class DonutSwinDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin +class DonutSwinSelfAttention(nn.Module): + def __init__(self, config, dim, num_heads, num_kv_heads, window_size): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.num_kv_heads = num_kv_heads + self.kv_repeats = self.num_attention_heads // self.num_kv_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.kv_head_size = self.num_kv_heads * self.attention_head_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.kv_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.all_head_size, self.kv_head_size, bias=config.qkv_bias) + + self.dropout_p = config.attention_probs_dropout_prob + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def transpose_kv_for_scores(self, x, repeats): + new_x_shape = x.size()[:-1] + (self.num_kv_heads, self.attention_head_size) + x = x.view(new_x_shape) + x = x.repeat(1, 1, repeats, 1) # repeat the values for each key-value head to match query dim + return x.permute(0, 2, 1, 3).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] 
= None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape + mixed_query_layer = self.query(hidden_states) + + # Final is (batch_size, num_attention_heads, seq_len, attention_head_size) + key_layer = self.transpose_kv_for_scores(self.key(hidden_states), self.kv_repeats) + value_layer = self.transpose_kv_for_scores(self.value(hidden_states), self.kv_repeats) + query_layer = self.transpose_for_scores(mixed_query_layer) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + if attention_mask is None: + attention_mask = relative_position_bias + else: + mask_shape = attention_mask.shape[0] + repeat_count = (batch_size // mask_shape) + attention_mask = attention_mask.repeat(repeat_count, 1, 1).unsqueeze(1) + attention_mask = attention_mask + relative_position_bias + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer.contiguous(), + key_layer.contiguous(), + value_layer.contiguous(), + attn_mask=attention_mask, + dropout_p=self.dropout_p if self.training else 0.0, + scale=self.attention_head_size**-0.5, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, dim, num_channels) + + outputs = (attn_output,) + return outputs + +# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput +class DonutSwinSelfOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin +class DonutSwinAttention(nn.Module): + def __init__(self, config, dim, num_heads, num_kv_heads, window_size): + super().__init__() + self.self = DonutSwinSelfAttention(config, dim, num_heads, num_kv_heads, window_size) + self.output = DonutSwinSelfOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) + attention_output = 
self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinIntermediate +class DonutSwinIntermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinOutput +class DonutSwinOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin +class DonutSwinLayer(nn.Module): + def __init__(self, config, dim, input_resolution, num_heads, num_kv_heads, shift_size=0): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.shift_size = shift_size self.window_size = config.window_size self.input_resolution = input_resolution self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) - self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size) + self.attention = DonutSwinAttention(config, dim, num_heads, num_kv_heads, window_size=self.window_size) self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.intermediate = DonutSwinIntermediate(config, dim) @@ -123,13 +477,15 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): def set_shift_and_window_size(self, input_resolution): if min(input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(input_resolution) + self.shift_size = int(0) + self.window_size = ( + torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) + ) - def get_attn_mask(self, height, width, dtype): + def get_attn_mask(self, height, width, dtype, device): if self.shift_size > 0: # calculate attention mask for SW-MSA - img_mask = torch.zeros((1, height, width, 1), dtype=dtype) + img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device) height_slices = ( slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), @@ -194,9 +550,9 @@ def forward( # partition windows hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) - attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype) - if attn_mask is not None: - attn_mask = attn_mask.to(hidden_states_windows.device) + attn_mask = self.get_attn_mask( + height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device + ) attention_outputs = self.attention( 
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions @@ -229,19 +585,21 @@ def forward( return layer_outputs -class VariableDonutSwinStage(nn.Module): - def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): +# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin +class DonutSwinStage(nn.Module): + def __init__(self, config, dim, input_resolution, depth, num_heads, num_kv_heads, drop_path, downsample): super().__init__() self.config = config self.dim = dim self.blocks = nn.ModuleList( [ - VariableDonutSwinLayer( + DonutSwinLayer( config=config, dim=dim, input_resolution=input_resolution, num_heads=num_heads, - shift_size=0 if (i % 2 == 0) else int(config.window_size // 2), + num_kv_heads=num_kv_heads, + shift_size=0 if (i % 2 == 0) else config.window_size // 2, ) for i in range(depth) ] @@ -288,7 +646,8 @@ def forward( return stage_outputs -class VariableDonutSwinEncoder(nn.Module): +# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin +class DonutSwinEncoder(nn.Module): def __init__(self, config, grid_size): super().__init__() self.num_layers = len(config.depths) @@ -296,14 +655,15 @@ def __init__(self, config, grid_size): dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))] self.layers = nn.ModuleList( [ - VariableDonutSwinStage( + DonutSwinStage( config=config, dim=int(config.embed_dim * 2**i_layer), input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), depth=config.depths[i_layer], num_heads=config.num_heads[i_layer], + num_kv_heads=config.num_kv_heads[i_layer], drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], - downsample=VariableDonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None, + downsample=DonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None, ) for i_layer in range(self.num_layers) ] @@ -389,22 +749,58 @@ def forward( ) -class VariableDonutSwinModel(DonutSwinModel): - config_class = VariableDonutSwinConfig - def __init__(self, config, add_pooling_layer=True, use_mask_token=False, **kwargs): +# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin +class DonutSwinPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
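`DonutSwinEncoder` above spreads the configured `drop_path_rate` linearly over all blocks with `torch.linspace(0, drop_path_rate, sum(depths))` and slices that schedule per stage. A tiny sketch of the resulting per-block rates, with made-up depths and rate:

```python
import torch

depths = [2, 2, 6, 2]        # toy stage depths
drop_path_rate = 0.1         # toy maximum rate

# One rate per block, increasing linearly from 0 to drop_path_rate
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

# Slice the flat schedule into per-stage chunks, as the encoder does
for i_layer in range(len(depths)):
    stage_rates = dpr[sum(depths[:i_layer]): sum(depths[:i_layer + 1])]
    print(i_layer, [round(r, 3) for r in stage_rates])
# 0 [0.0, 0.009]
# 1 [0.018, 0.027]
# 2 [0.036, 0.045, 0.055, 0.064, 0.073, 0.082]
# 3 [0.091, 0.1]
```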
+ """ + + config_class = DonutSwinConfig + base_model_prefix = "swin" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["DonutSwinStage"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class DonutSwinModel(DonutSwinPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): super().__init__(config) self.config = config self.num_layers = len(config.depths) self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) - self.embeddings = VariableDonutSwinEmbeddings(config, use_mask_token=use_mask_token) - self.encoder = VariableDonutSwinEncoder(config, self.embeddings.patch_grid) + self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid) - self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + self.position_embeddings = nn.Parameter(torch.zeros(1, config.encoder_length, config.hidden_size)) # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + def forward( self, pixel_values: Optional[torch.FloatTensor] = None, @@ -412,8 +808,8 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, - **kwargs ) -> Union[Tuple, DonutSwinModelOutput]: r""" bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): @@ -435,7 +831,9 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, len(self.config.depths)) - embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + embedding_output, input_dimensions = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) encoder_outputs = self.encoder( embedding_output, @@ -446,22 +844,9 @@ def forward( return_dict=return_dict, ) - sequence_output = encoder_outputs[0] - - pooled_output = None - if self.pooler is not None: - pooled_output = self.pooler(sequence_output.transpose(1, 2)) - pooled_output = torch.flatten(pooled_output, 1) - - if not return_dict: - output = (sequence_output, pooled_output) + encoder_outputs[1:] - - return output + last_hidden_state = encoder_outputs[0] + last_hidden_state += self.position_embeddings[:, :last_hidden_state.size(1), :] return DonutSwinModelOutput( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - 
reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, - ) + last_hidden_state=last_hidden_state, + ) \ No newline at end of file diff --git a/surya/model/recognition/encoderdecoder.py b/surya/model/recognition/encoderdecoder.py new file mode 100644 index 0000000..0ec83ef --- /dev/null +++ b/surya/model/recognition/encoderdecoder.py @@ -0,0 +1,145 @@ +from typing import Optional, Union, Tuple + +import torch +from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig +from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput +from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right +from surya.model.recognition.encoder import DonutSwinModel +from surya.model.recognition.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder + + +class OCREncoderDecoderModel(PreTrainedModel): + config_class = VisionEncoderDecoderConfig + base_model_prefix = "vision_encoder_decoder" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _supports_param_buffer_assignment = False + + def __init__( + self, + config: Optional[PretrainedConfig] = None, + encoder: Optional[PreTrainedModel] = None, + decoder: Optional[PreTrainedModel] = None, + text_encoder: Optional[PreTrainedModel] = None, + ): + # initialize with config + # make sure input & output embeddings is not tied + config.tie_word_embeddings = False + config.decoder.tie_word_embeddings = False + super().__init__(config) + + if encoder is None: + encoder = DonutSwinModel(config.encoder) + + if decoder is None: + decoder = SuryaOCRDecoder(config.decoder, attn_implementation=config._attn_implementation) + + if text_encoder is None: + text_encoder = SuryaOCRTextEncoder(config.text_encoder, attn_implementation=config._attn_implementation) + + self.encoder = encoder + self.decoder = decoder + self.text_encoder = text_encoder + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + self.encoder.config = self.config.encoder + self.decoder.config = self.config.decoder + self.text_encoder.config = self.config.text_encoder + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_cache_position: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + + kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + if encoder_outputs is None: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + encoder_outputs = self.encoder( + pixel_values=pixel_values, + **kwargs_encoder, + ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if ( + 
self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + # else: + encoder_attention_mask = None + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + cache_position=decoder_cache_position, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + **kwargs_decoder, + ) + + return Seq2SeqLMOutput( + logits=decoder_outputs.logits, + decoder_hidden_states=decoder_outputs.hidden_states, + encoder_last_hidden_state=encoder_outputs.last_hidden_state + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + ): + decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values) + decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None + input_dict = { + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_input_ids": decoder_inputs["input_ids"], + "encoder_outputs": encoder_outputs, + "past_key_values": decoder_inputs["past_key_values"], + "use_cache": use_cache, + } + return input_dict + + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the" + " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" + ) + + def _reorder_cache(self, past_key_values, beam_idx): + # apply decoder cache reordering here + return self.decoder._reorder_cache(past_key_values, beam_idx) \ No newline at end of file diff --git a/surya/model/recognition/model.py b/surya/model/recognition/model.py index 1ee2563..fcf9bd7 100644 --- a/surya/model/recognition/model.py +++ b/surya/model/recognition/model.py @@ -8,57 +8,42 @@ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) from typing import List, Optional, Tuple -from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, AutoModel, AutoModelForCausalLM -from surya.model.recognition.config import MBartMoEConfig, VariableDonutSwinConfig -from surya.model.recognition.encoder import VariableDonutSwinModel -from surya.model.recognition.decoder import MBartMoE +from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel +from surya.model.recognition.config import DonutSwinConfig, SuryaOCRConfig, SuryaOCRDecoderConfig, SuryaOCRTextEncoderConfig +from surya.model.recognition.encoder import DonutSwinModel +from surya.model.recognition.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder from surya.settings import settings +if not settings.ENABLE_EFFICIENT_ATTENTION: + print("Efficient attention is disabled. 
This will use significantly more VRAM.")
+    torch.backends.cuda.enable_mem_efficient_sdp(False)
+    torch.backends.cuda.enable_flash_sdp(True)
+    torch.backends.cuda.enable_math_sdp(True)

-def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE, langs: Optional[List[int]] = None):
-    config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)
-    # Prune moe experts that are not needed before loading the model
-    if langs is not None:
-        config.decoder.langs = {lang_iso : lang_int for lang_iso, lang_int in config.decoder.langs.items() if lang_int in langs}
+def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):

-    decoder_config = vars(config.decoder)
-    decoder = MBartMoEConfig(**decoder_config)
+    config = SuryaOCRConfig.from_pretrained(checkpoint)
+    decoder_config = config.decoder
+    decoder = SuryaOCRDecoderConfig(**decoder_config)
     config.decoder = decoder

-    encoder_config = vars(config.encoder)
-    encoder = VariableDonutSwinConfig(**encoder_config)
+    encoder_config = config.encoder
+    encoder = DonutSwinConfig(**encoder_config)
     config.encoder = encoder

-    # Get transformers to load custom encoder/decoder
-    AutoModel.register(MBartMoEConfig, MBartMoE)
-    AutoModelForCausalLM.register(MBartMoEConfig, MBartMoE)
-    AutoModel.register(VariableDonutSwinConfig, VariableDonutSwinModel)
+    text_encoder_config = config.text_encoder
+    text_encoder = SuryaOCRTextEncoderConfig(**text_encoder_config)
+    config.text_encoder = text_encoder

-    model = LangVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)
-    assert isinstance(model.decoder, MBartMoE)
-    assert isinstance(model.encoder, VariableDonutSwinModel)
+    model = OCREncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)
+
+    assert isinstance(model.decoder, SuryaOCRDecoder)
+    assert isinstance(model.encoder, DonutSwinModel)
+    assert isinstance(model.text_encoder, SuryaOCRTextEncoder)

     model = model.to(device)
     model = model.eval()
-    print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}")
-    return model
-
-
-class LangVisionEncoderDecoderModel(VisionEncoderDecoderModel):
-    def prepare_inputs_for_generation(
-        self, input_ids, decoder_langs=None, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
-    ):
-        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, langs=decoder_langs, past_key_values=past_key_values)
-        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
-        input_dict = {
-            "attention_mask": attention_mask,
-            "decoder_attention_mask": decoder_attention_mask,
-            "decoder_input_ids": decoder_inputs["input_ids"],
-            "encoder_outputs": encoder_outputs,
-            "past_key_values": decoder_inputs["past_key_values"],
-            "use_cache": use_cache,
-            "decoder_langs": decoder_inputs["langs"],
-        }
-        return input_dict
+    print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}")
+    return model
\ No newline at end of file
diff --git a/surya/model/recognition/processor.py b/surya/model/recognition/processor.py
index 1e5193a..7aa20a2 100644
--- a/surya/model/recognition/processor.py
+++ b/surya/model/recognition/processor.py
@@ -31,19 +31,9 @@ def __init__(self, *args, max_size=None, train=False, **kwargs):

     @classmethod
     def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):
-        height, width =
image.shape[:2] max_width, max_height = size["width"], size["height"] - if (height == max_height and width <= max_width) or (width == max_width and height <= max_height): - image = image.transpose(2, 0, 1) - return image - - scale = min(max_width / width, max_height / height) - - new_width = int(width * scale) - new_height = int(height * scale) - - resized_image = cv2.resize(image, (new_width, new_height), interpolation=interpolation) + resized_image = cv2.resize(image, (max_width, max_height), interpolation=interpolation) resized_image = resized_image.transpose(2, 0, 1) return resized_image @@ -191,7 +181,7 @@ def __init__(self, image_processor=None, tokenizer=None, train=False, **kwargs): def __call__(self, *args, **kwargs): images = kwargs.pop("images", None) text = kwargs.pop("text", None) - lang = kwargs.pop("lang", None) + langs = kwargs.pop("langs", None) if len(args) > 0: images = args[0] @@ -204,7 +194,7 @@ def __call__(self, *args, **kwargs): inputs = self.image_processor(images, *args, **kwargs) if text is not None: - encodings = self.tokenizer(text, lang, **kwargs) + encodings = self.tokenizer(text, langs, **kwargs) if text is None: return inputs diff --git a/surya/model/recognition/tokenizer.py b/surya/model/recognition/tokenizer.py index 27c062c..f4201cd 100644 --- a/surya/model/recognition/tokenizer.py +++ b/surya/model/recognition/tokenizer.py @@ -1,5 +1,7 @@ from itertools import chain -from typing import List, Union +import random +from typing import List, Optional, Tuple, Union +from tokenizers import AddedToken from transformers import ByT5Tokenizer import numpy as np import torch @@ -31,19 +33,18 @@ def utf16_numbers_to_text(numbers): return text -def _tokenize(text: str, langs: List[str], eos_token_id: int = 1, add_eos: bool = True, add_bos: bool = True): +def _tokenize(text: str, langs: List[str] | None, eos_token_id: int = 1, add_eos: bool = False, add_bos: bool = True): tokens = text_to_utf16_numbers(text) tokens = [t + TOKEN_OFFSET for t in tokens] # Account for special pad, etc, tokens lang_list = [] - for lang in langs: - code = LANGUAGE_MAP[lang] - lang_list.append(code + TOKEN_OFFSET + TOTAL_TOKENS) + if langs: + for lang in langs: + code = LANGUAGE_MAP[lang] + lang_list.append(code + TOKEN_OFFSET + TOTAL_TOKENS) tokens = lang_list + tokens - if add_eos: - tokens.append(eos_token_id) if add_bos: tokens.insert(0, eos_token_id) @@ -73,7 +74,7 @@ def __init__(self, super().__init__() - def __call__(self, texts: Union[List[str], str], langs: Union[List[List[str]], List[str]], pad_token_id: int = 0, **kwargs): + def __call__(self, texts: List[str] | str, langs: List[List[str]] | List[str] | None = None, pad_token_id: int = 0, **kwargs): tokenized = [] all_langs = [] @@ -83,10 +84,12 @@ def __call__(self, texts: Union[List[str], str], langs: Union[List[List[str]], L texts = [texts] is_list = False + if langs is None: + langs = [None] * len(texts) + if isinstance(langs[0], str): langs = [langs] - # One language input per text input assert len(langs) == len(texts) for text, lang in zip(texts, langs): diff --git a/surya/ocr.py b/surya/ocr.py index 1744098..e020e11 100644 --- a/surya/ocr.py +++ b/surya/ocr.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import List from PIL import Image @@ -8,7 +9,7 @@ from surya.schema import TextLine, OCRResult -def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model, rec_processor, bboxes: List[List[List[int]]] = None, polygons: List[List[List[List[int]]]] = None, batch_size=None) -> 
List[OCRResult]: +def run_recognition(images: List[Image.Image], langs: List[List[str] | None], rec_model, rec_processor, bboxes: List[List[List[int]]] = None, polygons: List[List[List[List[int]]]] = None, batch_size=None) -> List[OCRResult]: # Polygons need to be in corner format - [[x1, y1], [x2, y2], [x3, y3], [x4, y4]], bboxes in [x1, y1, x2, y2] format assert bboxes is not None or polygons is not None assert len(images) == len(langs), "You need to pass in one list of languages for each image" @@ -25,7 +26,7 @@ def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model slices = slice_bboxes_from_image(image, bboxes[idx]) slice_map.append(len(slices)) all_slices.extend(slices) - all_langs.extend([lang] * len(slices)) + all_langs.extend([deepcopy(lang)] * len(slices)) rec_predictions, _ = batch_recognition(all_slices, all_langs, rec_model, rec_processor, batch_size=batch_size) @@ -59,7 +60,7 @@ def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model return predictions_by_image -def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_processor, rec_model, rec_processor, batch_size=None) -> List[OCRResult]: +def run_ocr(images: List[Image.Image], langs: List[List[str] | None], det_model, det_processor, rec_model, rec_processor, batch_size=None) -> List[OCRResult]: images = convert_if_not_rgb(images) det_predictions = batch_text_detection(images, det_model, det_processor) diff --git a/surya/ordering.py b/surya/ordering.py index 0b87ba1..820bd74 100644 --- a/surya/ordering.py +++ b/surya/ordering.py @@ -38,12 +38,13 @@ def batch_ordering(images: List, bboxes: List[List[List[float]]], model: OrderVi if batch_size is None: batch_size = get_batch_size() - images = [image.convert("RGB") for image in images] # also copies the images output_order = [] for i in tqdm(range(0, len(images), batch_size), desc="Finding reading order"): batch_bboxes = deepcopy(bboxes[i:i+batch_size]) batch_images = images[i:i+batch_size] + batch_images = [image.convert("RGB") for image in batch_images] # also copies the images + orig_sizes = [image.size for image in batch_images] model_inputs = processor(images=batch_images, boxes=batch_bboxes) diff --git a/surya/postprocessing/heatmap.py b/surya/postprocessing/heatmap.py index 9cc14cb..c1ed38c 100644 --- a/surya/postprocessing/heatmap.py +++ b/surya/postprocessing/heatmap.py @@ -69,13 +69,6 @@ def get_dynamic_thresholds(linemap, text_threshold, low_text, typical_top10_avg= return text_threshold, low_text -def fast_contours_cumsum(segmap): - # Nonzero is slow, so use this, then mod and div to get x, y - # x and y are flipped in the output because openCV uses (y, x) instead of (x, y) - flat_indices = np.flatnonzero(segmap) - return np.column_stack((flat_indices % segmap.shape[1], flat_indices // segmap.shape[1])) - - def detect_boxes(linemap, text_threshold, low_text): # From CRAFT - https://github.com/clovaai/CRAFT-pytorch # Modified to return boxes and for speed, accuracy @@ -108,22 +101,28 @@ def detect_boxes(linemap, text_threshold, low_text): x, y, w, h = stats[k, [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT]] try: - niter = int(np.sqrt(size * min(w, h) / (w * h)) * 2) + niter = int(np.sqrt(min(w, h)) * 2) except ValueError: - # Overflow when size is too large + # Overflow in sqrt term niter = 0 + buffer = 1 sx, sy = max(0, x - niter), max(0, y - niter) - ex, ey = min(img_w, x + w + niter + 1), min(img_h, y + h + niter + 1) + ex, ey = min(img_w, x + w + niter + buffer), 
min(img_h, y + h + niter + buffer) segmap.fill(0) - segmap[mask] = 255 + segmap[mask] = 1 + + ksize = buffer + niter + kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(ksize, ksize)) - kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter)) - segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel) + # Doesn't work well without the zero start (ie, you can't trim the map tightly around the detected region) + selected_segmap = segmap[0:ey, 0:ex] + selected_segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel) # make box - np_contours = fast_contours_cumsum(segmap) + indices = np.nonzero(selected_segmap) + np_contours = np.column_stack((indices[1], indices[0])) rectangle = cv2.minAreaRect(np_contours) box = cv2.boxPoints(rectangle) @@ -148,7 +147,7 @@ def detect_boxes(linemap, text_threshold, low_text): if max_confidence > 0: confidences = [c / max_confidence for c in confidences] - return det, labels, confidences + return det, confidences def get_detected_boxes(textmap, text_threshold=None, low_text=None) -> List[PolygonBox]: @@ -160,7 +159,7 @@ def get_detected_boxes(textmap, text_threshold=None, low_text=None) -> List[Pol textmap = textmap.copy() textmap = textmap.astype(np.float32) - boxes, labels, confidences = detect_boxes(textmap, text_threshold, low_text) + boxes, confidences = detect_boxes(textmap, text_threshold, low_text) # From point form to box form boxes = [PolygonBox(polygon=box, confidence=confidence) for box, confidence in zip(boxes, confidences)] return boxes diff --git a/surya/recognition.py b/surya/recognition.py index 883ebfe..8fcd56c 100644 --- a/surya/recognition.py +++ b/surya/recognition.py @@ -2,7 +2,6 @@ import torch from PIL import Image -from surya.input.processing import convert_if_not_rgb from surya.postprocessing.math.latex import fix_math, contains_math from surya.postprocessing.text import truncate_repetitions from surya.settings import settings @@ -18,69 +17,61 @@ def get_batch_size(): if settings.TORCH_DEVICE_MODEL == "mps": batch_size = 64 # 12GB RAM max if settings.TORCH_DEVICE_MODEL == "cuda": - batch_size = 256 + batch_size = 512 return batch_size -def batch_recognition(images: List, languages: List[List[str]], model, processor, batch_size=None): +def pad_to_batch_size(tensor, batch_size): + current_batch_size = tensor.shape[0] + if current_batch_size >= batch_size: + return tensor + + pad_size = batch_size - current_batch_size + padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size) + + return F.pad(tensor, padding, mode='constant', value=0) + + +def batch_recognition(images: List, languages: List[List[str] | None], model, processor, batch_size=None): assert all([isinstance(image, Image.Image) for image in images]) assert len(images) == len(languages) if len(images) == 0: return [], [] - for l in languages: - assert len(l) <= settings.RECOGNITION_MAX_LANGS, f"OCR only supports up to {settings.RECOGNITION_MAX_LANGS} languages per image, you passed {l}." 
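`pad_to_batch_size` above relies on `F.pad` reading its pad tuple from the last dimension backwards, so only the batch dimension gets padded. The helper is reproduced here with toy shapes to make that concrete:

```python
import torch
import torch.nn.functional as F

def pad_to_batch_size(tensor, batch_size):
    current_batch_size = tensor.shape[0]
    if current_batch_size >= batch_size:
        return tensor

    pad_size = batch_size - current_batch_size
    # F.pad consumes (before, after) pairs starting from the last dimension,
    # so the trailing (0, pad_size) pair applies to dim 0, the batch dimension.
    padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)
    return F.pad(tensor, padding, mode='constant', value=0)

x = torch.ones(3, 5, 7)                   # a partial batch of 3 items
padded = pad_to_batch_size(x, batch_size=8)
print(padded.shape)                       # torch.Size([8, 5, 7])
print(padded[3:].abs().sum().item())      # 0.0 -> the extra rows are all zeros
```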
- - images = [image.convert("RGB") for image in images] # also copies the images if batch_size is None: batch_size = get_batch_size() + # Sort images by width, so similar length ones go together + sorted_pairs = sorted(enumerate(images), key=lambda x: x[1].width, reverse=False) + indices, images = zip(*sorted_pairs) + indices = list(indices) + images = list(images) + output_text = [] confidences = [] - - dec_config = model.config.decoder - layer_count = dec_config.decoder_layers - kv_heads = dec_config.kv_heads - head_dim = int(dec_config.d_model / dec_config.decoder_attention_heads) - min_val = torch.finfo(model.dtype).min - - if settings.RECOGNITION_STATIC_CACHE: - # We'll re-use these for all batches to avoid recopying - kv_mask = torch.full((batch_size, 1, 1, settings.RECOGNITION_MAX_TOKENS + 1), min_val, dtype=model.dtype, device=model.device) - # The +1 accounts for start token - initial_attn_mask = torch.full((batch_size, 1, settings.RECOGNITION_MAX_LANGS + 1, settings.RECOGNITION_MAX_LANGS + 1), min_val, dtype=model.dtype, device=model.device) - - # Decoder kv cache - # 7 (layers) x 2 (kv) x bs x 4 (heads) x max tokens x 64 (head dim) - decoder_cache = [torch.zeros((2, batch_size, kv_heads, settings.RECOGNITION_MAX_TOKENS, head_dim), dtype=model.dtype, device=model.device) for _ in range(layer_count)] - - # Prefill - decoder_input = torch.zeros((batch_size, settings.RECOGNITION_MAX_LANGS + 1), dtype=torch.long, device=model.device) - else: - initial_kv_mask = torch.zeros((batch_size, 1, 1, 1), dtype=model.dtype, device=model.device) - initial_attn_mask = torch.zeros((batch_size, 1, settings.RECOGNITION_MAX_LANGS + 1, settings.RECOGNITION_MAX_LANGS + 1), dtype=model.dtype, device=model.device) - - processed_batches = processor(text=[""] * len(images), images=images, lang=languages) - for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"): - batch_langs = languages[i:i+batch_size] - has_math = ["_math" in lang for lang in batch_langs] + batch_images = images[i:i+batch_size] + batch_images = [image.convert("RGB") for image in batch_images] # also copies the images - batch_pixel_values = processed_batches["pixel_values"][i:i+batch_size] - batch_langs = processed_batches["langs"][i:i+batch_size] - max_lang_len = max([len(lang) for lang in batch_langs]) + batch_langs = languages[i:i+batch_size] + has_math = [lang and "_math" in lang for lang in batch_langs] - # Pad languages to max length if needed, to ensure we can convert to a tensor - for lang_idx in range(len(batch_langs)): - lang_len = len(batch_langs[lang_idx]) - if lang_len < max_lang_len: - batch_langs[lang_idx] = [processor.tokenizer.pad_id] * (max_lang_len - lang_len) + batch_langs[lang_idx] + processed_batch = processor(text=[""] * len(batch_images), images=batch_images, langs=batch_langs) + batch_pixel_values = processed_batch["pixel_values"] + batch_langs = processed_batch["langs"] batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs] + max_input_length = max([len(tokens) for tokens in batch_decoder_input]) + + # Pad decoder input to max length if needed, to ensure we can convert to a tensor + for token_idx in range(len(batch_decoder_input)): + lang_len = len(batch_decoder_input[token_idx]) + if lang_len < max_input_length: + batch_decoder_input[token_idx] = [processor.tokenizer.pad_id] * (max_input_length - lang_len) + batch_decoder_input[token_idx] + current_batch_size = len(batch_pixel_values) - batch_langs = torch.tensor(np.stack(batch_langs, axis=0), 
@@ -88,112 +79,76 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
         inference_token_count = batch_decoder_input.shape[-1]
         batch_predictions = [[] for _ in range(current_batch_size)]

-        decoder_input_pad = torch.zeros((batch_size - current_batch_size, 1), dtype=torch.long, device=model.device)
-
-        if settings.RECOGNITION_STATIC_CACHE:
-            # Reset shared tensors
-            if i > 0:
-                # Decoder cache
-                for layer_cache in decoder_cache:
-                    layer_cache.fill_(0)
-
-                # KV mask
-                kv_mask.fill_(min_val)
-                kv_mask[:, :, :, -1] = 0
-                kv_mask[:, :, :, :inference_token_count] = 0
-
-                # Attention mask
-                initial_attn_mask.fill_(min_val)
+        decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, device=model.device).cumsum(0) - 1
+        model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
+        model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)

-            # Prefill
-            decoder_input.fill_(0)
-
-            # Prefill attention mask
-            attention_mask = initial_attn_mask
-            attention_mask[:, :, -inference_token_count:, -inference_token_count:] = 0
-
-            # Prefill input
-            decoder_input[:current_batch_size, -inference_token_count:] = batch_decoder_input
-            batch_decoder_input = decoder_input
-
-            # Pad to max batch size
-            batch_langs = torch.cat([batch_langs, torch.zeros((batch_size - current_batch_size, batch_langs.shape[-1]), dtype=torch.long, device=model.device)], dim=0)
-            batch_pixel_values = torch.cat([batch_pixel_values, torch.zeros((batch_size - current_batch_size,) + batch_pixel_values.shape[1:], dtype=model.dtype, device=model.device)], dim=0)
-        else:
-            # Select seed attention mask
-            kv_mask = initial_kv_mask[:current_batch_size]
-            kv_mask.fill_(0)
-
-            # Select prefill attention mask
-            attention_mask = initial_attn_mask[:current_batch_size, :, :inference_token_count, :inference_token_count]
-
-            decoder_cache = [None] * layer_count
-
-        encoder_outputs = None
         sequence_scores = None
-        encoder_cache = [None] * layer_count
         all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)
+        encoder_hidden_states = None

         with torch.no_grad():  # inference_mode doesn't work with torch.compile
-            # Run post-prefill tokens
-            while token_count < settings.RECOGNITION_MAX_TOKENS:
+            encoder_batch_size = batch_size // settings.RECOGNITION_ENCODER_BATCH_DIVISOR + 1
+            for z in range(0, batch_pixel_values.shape[0], encoder_batch_size):
+                encoder_pixel_values = batch_pixel_values[z:min(z + encoder_batch_size, batch_pixel_values.shape[0])]
+                encoder_hidden_states_batch = model.encoder(pixel_values=encoder_pixel_values).last_hidden_state
+                if encoder_hidden_states is None:
+                    encoder_hidden_states = encoder_hidden_states_batch
+                else:
+                    encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_batch], dim=0)
+
+            text_encoder_input_ids = torch.arange(
+                model.text_encoder.config.query_token_count,
+                device=encoder_hidden_states.device,
+                dtype=torch.long
+            ).unsqueeze(0).expand(encoder_hidden_states.size(0), -1)
+
+            encoder_text_hidden_states = model.text_encoder(
+                input_ids=text_encoder_input_ids,
+                cache_position=None,
+                attention_mask=None,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=None,
+                use_cache=False
+            ).hidden_states
+            del encoder_hidden_states
+
+            if settings.RECOGNITION_STATIC_CACHE:
+                # Pad inputs to max batch size for static cache
+                encoder_text_hidden_states = pad_to_batch_size(encoder_text_hidden_states, batch_size)
+                batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)
+
+            while token_count < settings.RECOGNITION_MAX_TOKENS - 1:
                 is_prefill = token_count == 0
-                return_dict = model(
-                    decoder_input_ids=batch_decoder_input,
-                    decoder_attention_mask=attention_mask,
-                    decoder_self_kv_cache=None if is_prefill else decoder_cache,
-                    decoder_cross_kv_cache=None if is_prefill else encoder_cache,
-                    decoder_past_token_count=token_count,
-                    decoder_langs=batch_langs,
-                    pixel_values=batch_pixel_values,
-                    encoder_outputs=encoder_outputs,
-                    return_dict=True,
+                #TODO: add attention mask
+                return_dict = model.decoder(
+                    input_ids=batch_decoder_input,
+                    encoder_hidden_states=encoder_text_hidden_states,
+                    cache_position=decoder_position_ids,
+                    use_cache=True,
+                    prefill=is_prefill
                 )

+                decoder_position_ids = decoder_position_ids[-1:] + 1
                 logits = return_dict["logits"][:current_batch_size]  # Ignore batch padding
+                aux_logits = return_dict.get("aux_logits", None)
+
                 preds = torch.argmax(logits[:, -1], dim=-1)
-                scores = torch.max(F.softmax(logits, dim=-1), dim=-1).values
+                scores = torch.max(F.softmax(logits[:, -1], dim=-1), dim=-1).values.unsqueeze(1)

                 done = (preds == processor.tokenizer.eos_id) | (preds == processor.tokenizer.pad_id)
                 done = done
                 all_done = all_done | done

-                scores[all_done == 1] = 0
-
                 if is_prefill:
                     sequence_scores = scores
-                    encoder_outputs = (return_dict["encoder_last_hidden_state"],)
                 else:
+                    scores = scores.masked_fill(all_done, 0)
                     sequence_scores = torch.cat([sequence_scores, scores], dim=1)

                 if all_done.all():
                     break

-                past_key_values = return_dict["past_key_values"]
-                token_range = torch.arange(token_count, token_count + inference_token_count, device=model.device)
-
-                for layer_idx, layer in enumerate(past_key_values):
-                    if is_prefill:
-                        encoder_cache[layer_idx] = layer[1]
-
-                    if settings.RECOGNITION_STATIC_CACHE:
-                        # Fill in entries in static kv cache
-                        decoder_cache[layer_idx][:, :, :, token_range, :] = layer[0][:, :, :, -inference_token_count:, :]
-                    else:
-                        # Cat to generate new kv cache including current tokens
-                        if is_prefill:
-                            decoder_cache[layer_idx] = layer[0]
-                        else:
-                            decoder_cache[layer_idx] = torch.cat([decoder_cache[layer_idx], layer[0]], dim=3)
-
                 batch_decoder_input = preds.unsqueeze(1)
-                if settings.RECOGNITION_STATIC_CACHE:
-                    # Setup new attention mask and input token
-                    kv_mask[:, :, :, token_count:(token_count + inference_token_count)] = 0
-                    batch_decoder_input = torch.cat([batch_decoder_input, decoder_input_pad], dim=0)  # Pad to full batch
-                else:
-                    kv_mask = torch.cat([kv_mask, torch.zeros((current_batch_size, 1, 1, inference_token_count), dtype=model.dtype, device=model.device)], dim=-1)
-
-                attention_mask = kv_mask

                 for j, (pred, status) in enumerate(zip(preds, all_done)):
                     if not status:
@@ -201,6 +156,11 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
                 token_count += inference_token_count
                 inference_token_count = batch_decoder_input.shape[-1]
+                max_position_id = torch.max(decoder_position_ids).item()
+                decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, device=model.device).cumsum(0) - 1 + max_position_id
+
+                if settings.RECOGNITION_STATIC_CACHE:
+                    batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)

         sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
         detected_text = processor.tokenizer.batch_decode(batch_predictions)
@@ -211,6 +171,12 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
         output_text.extend(detected_text)
         confidences.extend(sequence_scores.tolist())
+        del encoder_text_hidden_states
+
+    output_text = sorted(zip(indices, output_text), key=lambda x: x[0])
+    confidences = sorted(zip(indices, confidences), key=lambda x: x[0])
+    output_text = [text for _, text in output_text]
+    confidences = [conf for _, conf in confidences]

     return output_text, confidences
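The recognition changes above sort the inputs by width before batching and then restore the caller's order at the end using the saved `indices`. A small self-contained sketch of that sort-then-restore pattern, with invented widths and faked recognition output:

```python
# Invented widths standing in for image.width; the OCR call is faked as a string.
widths = [640, 120, 300]
sorted_pairs = sorted(enumerate(widths), key=lambda x: x[1])
indices, sorted_widths = map(list, zip(*sorted_pairs))

results = [f"text for width {w}" for w in sorted_widths]  # pretend recognition output

# Re-attach the original indices and sort back to the caller's order.
restored = [text for _, text in sorted(zip(indices, results), key=lambda x: x[0])]
print(restored)  # ['text for width 640', 'text for width 120', 'text for width 300']
```

Grouping similarly sized images keeps padding within each batch small, which is the point of the width sort.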
diff --git a/surya/schema.py b/surya/schema.py
index 129f991..88d42cb 100644
--- a/surya/schema.py
+++ b/surya/schema.py
@@ -140,7 +140,7 @@ class TextLine(PolygonBox):

 class OCRResult(BaseModel):
     text_lines: List[TextLine]
-    languages: List[str]
+    languages: List[str] | None = None
     image_bbox: List[float]
diff --git a/surya/settings.py b/surya/settings.py
index 9e47ae8..33c4bdd 100644
--- a/surya/settings.py
+++ b/surya/settings.py
@@ -10,8 +10,9 @@ class Settings(BaseSettings):

     # General
     TORCH_DEVICE: Optional[str] = None
-    IMAGE_DPI: int = 96
+    IMAGE_DPI: int = 192
     IN_STREAMLIT: bool = False  # Whether we're running in streamlit
+    ENABLE_EFFICIENT_ATTENTION: bool = True  # Usually keep True, but if you get CUDA errors, setting to False can help

     # Paths
     DATA_DIR: str = "data"
@@ -43,10 +44,10 @@ def TORCH_DEVICE_MODEL(self) -> str:
     DETECTOR_MIN_PARALLEL_THRESH: int = 3  # Minimum number of images before we parallelize

     # Text recognition
-    RECOGNITION_MODEL_CHECKPOINT: str = "vikp/surya_rec"
+    RECOGNITION_MODEL_CHECKPOINT: str = "vikp/surya_rec2"
     RECOGNITION_MAX_TOKENS: int = 175
     RECOGNITION_BATCH_SIZE: Optional[int] = None  # Defaults to 8 for CPU/MPS, 256 otherwise
-    RECOGNITION_IMAGE_SIZE: Dict = {"height": 196, "width": 896}
+    RECOGNITION_IMAGE_SIZE: Dict = {"height": 256, "width": 896}
     RECOGNITION_RENDER_FONTS: Dict[str, str] = {
         "all": os.path.join(FONT_DIR, "GoNotoCurrent-Regular.ttf"),
         "zh": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"),
@@ -57,7 +58,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
     RECOGNITION_BENCH_DATASET_NAME: str = "vikp/rec_bench"
     RECOGNITION_PAD_VALUE: int = 255  # Should be 0 or 255
     RECOGNITION_STATIC_CACHE: bool = False  # Static cache for torch compile
-    RECOGNITION_MAX_LANGS: int = 4
+    RECOGNITION_ENCODER_BATCH_DIVISOR: int = 2  # Divisor applied to the decoder batch size to get the encoder batch size

     # Layout
     LAYOUT_MODEL_CHECKPOINT: str = "vikp/surya_layout3"
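Since `Settings` subclasses pydantic's `BaseSettings`, the defaults above can still be overridden per run through environment variables set before import. A quick sketch checking the resolved values; it assumes this revision of surya is installed, and the override value is arbitrary:

```python
import os

# Arbitrary override; must be set before surya.settings is imported.
os.environ["RECOGNITION_BATCH_SIZE"] = "256"

from surya.settings import settings

print(settings.IMAGE_DPI)                     # 192 with this change
print(settings.RECOGNITION_MODEL_CHECKPOINT)  # vikp/surya_rec2
print(settings.RECOGNITION_BATCH_SIZE)        # 256, taken from the env var above
```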