
Merge pull request VikParuchuri#195 from VikParuchuri/tr2
Tr2
VikParuchuri authored Oct 4, 2024
2 parents 4ffc9bb + 50a8589 commit 663e11c
Showing 27 changed files with 2,577 additions and 80 deletions.
9 changes: 5 additions & 4 deletions .github/workflows/tests.yml
@@ -36,10 +36,11 @@ jobs:
        run: |
          poetry run python benchmark/layout.py --max 5
          poetry run python scripts/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout
-      - name: Run ordering benchmark text
+      - name: Run ordering benchmark
        run: |
          poetry run python benchmark/ordering.py --max 5
          poetry run python scripts/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering
      - name: Run table recognition benchmark
        run: |
          poetry run python benchmark/table_recognition.py --max 5
          poetry run python scripts/verify_benchmark_scores.py results/benchmark/table_rec_bench/results.json --bench_type table_recognition
66 changes: 61 additions & 5 deletions README.md
@@ -6,6 +6,7 @@ Surya is a document OCR toolkit that does:
- Line-level text detection in any language
- Layout analysis (table, image, header, etc. detection)
- Reading order detection
- Table recognition (detecting rows/columns)

It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details).

@@ -272,6 +273,43 @@ processor = load_processor()
order_predictions = batch_ordering([image], [bboxes], model, processor)
```

## Table Recognition

This command will write out a json file with the detected table cells and row/column ids, along with row/column bounding boxes.

```shell
surya_table DATA_PATH
```

- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--images` will save images of the pages and detected table cells + rows and columns (optional)
- `--max` specifies the maximum number of pages to process if you don't want to process everything
- `--results_dir` specifies the directory to save results to instead of the default
- `--detect_boxes` specifies whether cell text boxes should be detected with the text detection model. By default, they're pulled directly out of the PDF, but this is not always possible.
- `--skip_table_detection` tells table recognition not to detect tables first. Use this if your image is already cropped to a table.
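
Table recognition can also be called from Python. Below is a minimal sketch following the same call pattern as `benchmark/table_recognition.py` in this commit; the image path and the hand-made text box are hypothetical placeholders.

```python
from PIL import Image

from surya.model.table_rec.model import load_model
from surya.model.table_rec.processor import load_processor
from surya.tables import batch_table_recognition

# Hypothetical input: an image already cropped to a single table
image = Image.open("table.png").convert("RGB")
# One list of text boxes per image; text can be None if unknown
bboxes = [[{"bbox": [10, 10, 120, 40], "text": None}]]

model = load_model()
processor = load_processor()
predictions = batch_table_recognition([image], bboxes, model, processor)
print(predictions[0].rows, predictions[0].cols)
```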

The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:

- `cells` - detected table cells
- `bbox` - the axis-aligned rectangle for the cell in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
- `row_id` - the id of the row this cell belongs to.
- `col_id` - the id of the column this cell belongs to.
- `text` - the text of this cell, if it could be pulled out of the PDF.
- `rows` - detected table rows
- `bbox` - the bounding box of the table row
- `row_id` - the id of the row
- `cols` - detected table columns
- `bbox` - the bounding box of the table column
- `col_id` - the id of the column
- `page` - the page number in the file
- `table_idx` - the index of the table on the page (sorted in vertical order)
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All row, column, and cell bboxes will be contained within this bbox.
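
As an illustration, here is a minimal sketch of reading this output back in Python. The input name `tables` and the results path are hypothetical; the actual path depends on your input filename and `--results_dir`.

```python
import json

# Hypothetical path - adjust to wherever surya_table wrote results.json
with open("results/surya/tables/results.json") as f:
    results = json.load(f)

for page in results["tables"]:  # key is the input filename without extension
    print(f"page {page['page']}, table {page['table_idx']}")
    for cell in page["cells"]:
        print(cell["row_id"], cell["col_id"], cell["bbox"], cell["text"])
```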

**Performance tips**

Setting the `TABLE_REC_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `150MB` of VRAM, so very high batch sizes are possible. The default is a batch size of `64`, which will use about 10GB of VRAM. A larger batch size can also help on CPU, depending on your core count - the default CPU batch size is `8`.
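
For example, on a Unix-like shell you could run `TABLE_REC_BATCH_SIZE=128 surya_table DATA_PATH`; at roughly `150MB` per batch item, a batch size of `128` needs about 19GB of VRAM, so this assumes a correspondingly large GPU.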


# Limitations

- This is specialized for document OCR. It will likely not work on photos or other images.
@@ -381,10 +419,18 @@ I benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/

**Methodology**

-I benchmarked the layout analysis on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data. Unfortunately, this dataset is fairly noisy, and not all the labels are correct. It was very hard to find a dataset annotated with reading order and also layout information. I wanted to avoid using a cloud service for the ground truth.
+I benchmarked the reading order on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data. Unfortunately, this dataset is fairly noisy, and not all the labels are correct. It was very hard to find a dataset annotated with reading order and also layout information. I wanted to avoid using a cloud service for the ground truth.

The accuracy is computed by checking whether each pair of layout boxes is predicted in the correct order, then taking the percentage of pairs that are correct.
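
A minimal sketch of this pairwise metric (an illustrative helper, not necessarily the repo's exact implementation):

```python
from itertools import combinations

def pairwise_order_accuracy(pred_positions, true_positions):
    # For every pair of boxes, check whether the predicted order
    # agrees with the ground-truth order, then average over all pairs.
    pairs = list(combinations(range(len(true_positions)), 2))
    correct = sum(
        (pred_positions[i] < pred_positions[j]) == (true_positions[i] < true_positions[j])
        for i, j in pairs
    )
    return correct / len(pairs) if pairs else 1.0

# Three boxes with one swap: 2 of 3 pairs are ordered correctly
print(pairwise_order_accuracy([0, 2, 1], [0, 1, 2]))  # ~0.67
```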

## Table Recognition

0.93 penalized row IoU (out of 1), and 0.86 penalized column IoU. Took 0.05 seconds per image on an A10.

**Methodology**

The benchmark uses a subset of [Fintabnet](https://developer.ibm.com/exchanges/data/all/fintabnet/) from IBM. It has labeled rows and columns. After table recognition is run, the predicted rows and columns are compared to the ground truth. There is an additional penalty for predicting too many or too few rows/columns.
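
A rough sketch of what a penalized score along these lines could look like (an illustrative stand-in, not necessarily how `penalized_iou_score` in `surya.benchmark.metrics` is implemented):

```python
def box_iou(a, b):
    # Intersection-over-union of two (x1, y1, x2, y2) boxes
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def penalized_iou_sketch(preds, truths):
    # Match each ground-truth row/column to its best-overlapping prediction,
    # then divide by the larger count, so predicting too many or too few
    # rows/columns drags the score down.
    if not truths:
        return 1.0 if not preds else 0.0
    total = sum(max((box_iou(p, t) for p in preds), default=0.0) for t in truths)
    return total / max(len(preds), len(truths))
```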

## Running your own benchmarks

You can benchmark the performance of surya on your machine.
@@ -396,7 +442,7 @@ You can benchmark the performance of surya on your machine.

This will evaluate tesseract and surya for text line detection across a randomly sampled set of images from [doclaynet](https://huggingface.co/datasets/vikp/doclaynet_bench).

-```
+```shell
python benchmark/detection.py --max 256
```

@@ -409,7 +455,7 @@

This will evaluate surya and optionally tesseract on multilingual pdfs from common crawl (with synthetic data for missing languages).

-```
+```shell
python benchmark/recognition.py --tesseract
```

@@ -425,7 +471,7 @@

This will evaluate surya on the publaynet dataset.

-```
+```shell
python benchmark/layout.py
```

@@ -435,14 +481,24 @@

**Reading Order**

-```
+```shell
python benchmark/ordering.py
```

- `--max` controls how many images to process for the benchmark
- `--debug` will render images with detected text
- `--results_dir` will let you specify a directory to save results to instead of the default one

**Table Recognition**

```shell
python benchmark/table_recognition.py
```

- `--max` controls how many images to process for the benchmark
- `--debug` will render images with detected rows and columns
- `--results_dir` will let you specify a directory to save results to instead of the default one

# Training

Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified efficientvit architecture for semantic segmentation.
88 changes: 88 additions & 0 deletions benchmark/table_recognition.py
@@ -0,0 +1,88 @@
import argparse
import collections
import copy
import json

from surya.input.processing import convert_if_not_rgb
from surya.model.table_rec.model import load_model
from surya.model.table_rec.processor import load_processor
from surya.tables import batch_table_recognition
from surya.settings import settings
from surya.benchmark.metrics import rank_accuracy, penalized_iou_score
import os
import time
import datasets


def main():
    parser = argparse.ArgumentParser(description="Benchmark surya table recognition model.")
    parser.add_argument("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
    parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=None)
    args = parser.parse_args()

    model = load_model()
    processor = load_processor()

    pathname = "table_rec_bench"
    # These have already been shuffled randomly, so sampling from the start is fine
    split = "train"
    if args.max is not None:
        split = f"train[:{args.max}]"
    dataset = datasets.load_dataset(settings.TABLE_REC_BENCH_DATASET_NAME, split=split)
    images = list(dataset["image"])
    images = convert_if_not_rgb(images)
    bboxes = list(dataset["bboxes"])

    start = time.time()
    bboxes = [[{"bbox": b, "text": None} for b in bb] for bb in bboxes]
    table_rec_predictions = batch_table_recognition(images, bboxes, model, processor)
    surya_time = time.time() - start

    folder_name = os.path.basename(pathname).split(".")[0]
    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    page_metrics = collections.OrderedDict()
    mean_col_iou = 0
    mean_row_iou = 0
    for idx, pred in enumerate(table_rec_predictions):
        row = dataset[idx]
        pred_row_boxes = [p.bbox for p in pred.rows]
        pred_col_bboxes = [p.bbox for p in pred.cols]
        actual_row_bboxes = row["rows"]
        actual_col_bboxes = row["cols"]
        row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)
        col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)
        page_results = {
            "row_score": row_score,
            "col_score": col_score,
            "row_count": len(actual_row_bboxes),
            "col_count": len(actual_col_bboxes)
        }

        mean_col_iou += col_score
        mean_row_iou += row_score

        page_metrics[idx] = page_results

    mean_col_iou /= len(table_rec_predictions)
    mean_row_iou /= len(table_rec_predictions)

    out_data = {
        "time": surya_time,
        "mean_row_iou": mean_row_iou,
        "mean_col_iou": mean_col_iou,
        "page_metrics": page_metrics
    }

    with open(os.path.join(result_path, "results.json"), "w+") as f:
        json.dump(out_data, f, indent=4)

    print(f"Mean penalized row iou is {mean_row_iou:.2f}. Mean penalized column iou is {mean_col_iou:.2f}.")
    print(f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total.")
    print("Mean iou is the average of the iou scores for each row or column, with penalties for too many/few predictions.")
    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
4 changes: 2 additions & 2 deletions detect_layout.py
@@ -27,10 +27,10 @@ def main():
    det_processor = load_processor()

    if os.path.isdir(args.input_path):
-        images, names = load_from_folder(args.input_path, args.max)
+        images, names, _ = load_from_folder(args.input_path, args.max)
        folder_name = os.path.basename(args.input_path)
    else:
-        images, names = load_from_file(args.input_path, args.max)
+        images, names, _ = load_from_file(args.input_path, args.max)
        folder_name = os.path.basename(args.input_path).split(".")[0]

    line_predictions = batch_text_detection(images, det_model, det_processor)
4 changes: 2 additions & 2 deletions detect_text.py
@@ -28,10 +28,10 @@ def main():
    processor = load_processor(checkpoint=checkpoint)

    if os.path.isdir(args.input_path):
-        images, names = load_from_folder(args.input_path, args.max)
+        images, names, _ = load_from_folder(args.input_path, args.max)
        folder_name = os.path.basename(args.input_path)
    else:
-        images, names = load_from_file(args.input_path, args.max)
+        images, names, _ = load_from_file(args.input_path, args.max)
        folder_name = os.path.basename(args.input_path).split(".")[0]

    start = time.time()
65 changes: 60 additions & 5 deletions ocr_app.py
@@ -4,21 +4,26 @@
import pypdfium2
import streamlit as st
from surya.detection import batch_text_detection
from surya.input.pdflines import get_page_text_lines, get_table_blocks
from surya.layout import batch_layout_detection
from surya.model.detection.model import load_model, load_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.model.ordering.model import load_model as load_order_model
from surya.model.table_rec.model import load_model as load_table_model
from surya.model.table_rec.processor import load_processor as load_table_processor
from surya.ordering import batch_ordering
-from surya.postprocessing.heatmap import draw_polys_on_image
+from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
from surya.ocr import run_ocr
from surya.postprocessing.text import draw_text_on_image
from PIL import Image
from surya.languages import CODE_TO_LANGUAGE
from surya.input.langs import replace_lang_with_code
-from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult
+from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult, TableResult
from surya.settings import settings
from surya.tables import batch_table_recognition


@st.cache_resource()
def load_det_cached():
@@ -40,6 +45,11 @@ def load_order_cached():
    return load_order_model(), load_order_processor()


@st.cache_resource()
def load_table_cached():
    return load_table_model(), load_table_processor()


def text_detection(img) -> (Image.Image, TextDetectionResult):
    pred = batch_text_detection([img], det_model, det_processor)[0]
    polygons = [p.polygon for p in pred.bboxes]
@@ -52,7 +62,7 @@ def layout_detection(img) -> (Image.Image, LayoutResult):
    pred = batch_layout_detection([img], layout_model, layout_processor, [det_pred])[0]
    polygons = [p.polygon for p in pred.bboxes]
    labels = [p.label for p in pred.bboxes]
-    layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels)
+    layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
    return layout_img, pred


Expand All @@ -62,10 +72,43 @@ def order_detection(img) -> (Image.Image, OrderResult):
    pred = batch_ordering([img], [bboxes], order_model, order_processor)[0]
    polys = [l.polygon for l in pred.bboxes]
    positions = [str(l.position) for l in pred.bboxes]
-    order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=20)
+    order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=18)
    return order_img, pred


def table_recognition(img, filepath, page_idx: int, use_pdf_boxes: bool, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
    if skip_table_detection:
        # Treat the whole image as a single table
        layout_tables = [(0, 0, img.size[0], img.size[1])]
        table_imgs = [img]
    else:
        # Find tables via layout analysis, then crop each one out
        _, layout_pred = layout_detection(img)
        layout_tables = [l.bbox for l in layout_pred.bboxes if l.label == "Table"]
        table_imgs = [img.crop(tb) for tb in layout_tables]

    if use_pdf_boxes:
        # Pull text line boxes directly from the PDF
        page_text = get_page_text_lines(filepath, [page_idx], [img.size])[0]
        table_bboxes = get_table_blocks(layout_tables, page_text, img.size)
    else:
        # Detect text boxes in each cropped table image
        det_results = batch_text_detection(table_imgs, det_model, det_processor)
        table_bboxes = [[{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes] for det_result in det_results]
    table_preds = batch_table_recognition(table_imgs, table_bboxes, table_model, table_processor)
    table_img = img.copy()

    for results, table_bbox in zip(table_preds, layout_tables):
        # Cell boxes are relative to the table crop; shift them back to page coordinates
        adjusted_bboxes = []
        labels = []
        for item in results.cells:
            adjusted_bboxes.append([
                item.bbox[0] + table_bbox[0],
                item.bbox[1] + table_bbox[1],
                item.bbox[2] + table_bbox[0],
                item.bbox[3] + table_bbox[1]
            ])
            labels.append(f"{item.row_id} / {item.col_id}")
        table_img = draw_bboxes_on_image(adjusted_bboxes, table_img, labels=labels, label_font_size=18)
    return table_img, table_preds


# Function for OCR
def ocr(img, langs: List[str]) -> (Image.Image, OCRResult):
    replace_lang_with_code(langs)
@@ -83,7 +126,7 @@ def open_pdf(pdf_file):


@st.cache_data()
-def get_page_image(pdf_file, page_num, dpi=96):
+def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):
    doc = open_pdf(pdf_file)
    renderer = doc.render(
        pypdfium2.PdfBitmap.to_pil,
@@ -108,6 +151,7 @@ def page_count(pdf_file):
rec_model, rec_processor = load_rec_cached()
layout_model, layout_processor = load_layout_cached()
order_model, order_processor = load_order_cached()
table_model, table_processor = load_table_cached()


st.markdown("""
@@ -139,11 +183,15 @@ def page_count(pdf_file):
    pil_image = get_page_image(in_file, page_number)
else:
    pil_image = Image.open(in_file).convert("RGB")
    page_number = None

text_det = st.sidebar.button("Run Text Detection")
text_rec = st.sidebar.button("Run OCR")
layout_det = st.sidebar.button("Run Layout Analysis")
order_det = st.sidebar.button("Run Reading Order")
table_rec = st.sidebar.button("Run Table Rec")
use_pdf_boxes = st.sidebar.checkbox("PDF table boxes", value=True, help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.")
skip_table_detection = st.sidebar.checkbox("Skip table detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.")

if pil_image is None:
    st.stop()
@@ -180,5 +228,12 @@
        st.image(order_img, caption="Reading Order", use_column_width=True)
        st.json(pred.model_dump(), expanded=True)


if table_rec:
    table_img, pred = table_recognition(pil_image, in_file, page_number - 1 if page_number else None, use_pdf_boxes, skip_table_detection)
    with col1:
        st.image(table_img, caption="Table Recognition", use_column_width=True)
        st.json([p.model_dump() for p in pred], expanded=True)

with col2:
    st.image(pil_image, caption="Uploaded Image", use_column_width=True)
