Refactor to move cell assignment out of library
VikParuchuri committed Oct 11, 2024
1 parent 7af11c1 commit 4fa6ff6
Showing 8 changed files with 90 additions and 104 deletions.
4 changes: 2 additions & 2 deletions CLA.md
@@ -1,6 +1,6 @@
Surya Contributor Agreement

This Surya Contributor Agreement ("SCA") applies to any contribution that you make to any product or project managed by us (the "project"), and sets out the intellectual property rights you grant to us in the contributed materials. The term "us" shall mean Vikas Paruchuri. The term "you" shall mean the person or entity identified below.
This Surya Contributor Agreement ("SCA") applies to any contribution that you make to any product or project managed by us (the "project"), and sets out the intellectual property rights you grant to us in the contributed materials. The term "us" shall mean Endless Labs, Inc. The term "you" shall mean the person or entity identified below.

If you agree to be bound by these terms, sign by writing "I have read the CLA document and I hereby sign the CLA" in response to the CLA bot Github comment. Read this agreement carefully before signing. These terms and conditions constitute a binding legal agreement.

@@ -20,5 +20,5 @@ If you or your affiliates institute patent litigation against any entity (includ
- each contribution that you submit is and shall be an original work of authorship and you can legally grant the rights set out in this SCA;
- to the best of your knowledge, each contribution will not violate any third party's copyrights, trademarks, patents, or other intellectual property rights; and
- each contribution shall be in compliance with U.S. export control laws and other applicable export and import laws.
You agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Vikas Paruchuri may publicly disclose your participation in the project, including the fact that you have signed the SCA.
You agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Endless Labs, Inc. may publicly disclose your participation in the project, including the fact that you have signed the SCA.
6. This SCA is governed by the laws of the State of California and applicable U.S. Federal law. Any choice of law rules will not apply.
2 changes: 0 additions & 2 deletions README.md
@@ -298,8 +298,6 @@ The `results.json` file will contain a json dictionary where the keys are the in

- `cells` - detected table cells
- `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
- `row_id` - the id of the row this cell belongs to.
- `col_id` - the id of the column this cell belongs to.
- `text` - if text could be pulled out of the pdf, the text of this cell.
- `rows` - detected table rows
- `bbox` - the bounding box of the table row
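Note on the README change above: cells in `results.json` no longer carry `row_id`/`col_id`; those ids now live on the row and column entries (see the schema changes further down). A minimal sketch of how downstream code might walk the reshaped output — only `cells`, `bbox`, `text`, `rows`, and `row_id` come from the README/schema in this diff; the surrounding nesting is an assumption for illustration:

```python
import json

# Hypothetical walk over one table result in results.json after this commit.
results = json.load(open("results.json"))
table = results["example"][0]          # assumed: keys are input names, values are lists

for cell in table["cells"]:
    print(cell["bbox"], cell.get("text"))        # cells keep bbox + optional text only
for row in table["rows"]:
    print("row", row["row_id"], row["bbox"])     # row/col ids now live on rows/cols
```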
11 changes: 8 additions & 3 deletions ocr_app.py
@@ -112,16 +112,21 @@ def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes:
for results, table_bbox in zip(table_preds, layout_tables):
adjusted_bboxes = []
labels = []
colors = []

for item in results.cells:
for item in results.rows + results.cols:
adjusted_bboxes.append([
(item.bbox[0] + table_bbox[0]),
(item.bbox[1] + table_bbox[1]),
(item.bbox[2] + table_bbox[0]),
(item.bbox[3] + table_bbox[1])
])
labels.append(f"{item.row_id} / {item.col_id}")
table_img = draw_bboxes_on_image(adjusted_bboxes, highres_img, labels=labels, label_font_size=18)
labels.append(item.label)
if hasattr(item, "row_id"):
colors.append("blue")
else:
colors.append("red")
table_img = draw_bboxes_on_image(adjusted_bboxes, highres_img, labels=labels, label_font_size=18, color=colors)
return table_img, table_preds


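The loop above translates each predicted row/column box from table-relative coordinates back into page coordinates by adding the table's top-left corner, and colors rows blue and columns red. A tiny standalone sketch of the coordinate shift (the helper name is hypothetical, not part of the app):

```python
def to_page_coords(bbox, table_bbox):
    # Shift a table-relative (x1, y1, x2, y2) box into page coordinates by
    # adding the table's top-left corner, as the loop in ocr_app.py does inline.
    x1, y1, x2, y2 = bbox
    tx, ty = table_bbox[0], table_bbox[1]
    return [x1 + tx, y1 + ty, x2 + tx, y2 + ty]

# A row box at (5, 10, 200, 30) inside a table whose page bbox starts at (100, 400):
print(to_page_coords([5, 10, 200, 30], [100, 400, 600, 800]))  # [105, 410, 300, 430]
```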
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "surya-ocr"
version = "0.6.1"
version = "0.6.2"
description = "OCR, layout, reading order, and table recognition in 90+ languages"
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
65 changes: 56 additions & 9 deletions surya/input/pdflines.py
Expand Up @@ -2,6 +2,7 @@

from surya.postprocessing.text import sort_text_lines
from surya.schema import PolygonBox
import numpy as np


def get_page_text_lines(filepath: str, page_idxs: list, out_sizes: list) -> list:
@@ -23,9 +24,62 @@ def get_page_text_lines(filepath: str, page_idxs: list, out_sizes: list) -> list
return pages_text


def get_table_blocks(tables: list, full_text: dict, img_size: list, table_thresh=.8):
def get_dynamic_gap_thresh(full_text: dict, img_size: list, default_thresh=.01, min_chars=100):
space_dists = []
for block in full_text["blocks"]:
for line in block["lines"]:
for span in line["spans"]:
for i in range(1, len(span["chars"])):
char1 = span["chars"][i - 1]
char2 = span["chars"][i]
if full_text["rotation"] == 90:
space_dists.append((char2["bbox"][0] - char1["bbox"][2]) / img_size[0])
elif full_text["rotation"] == 180:
space_dists.append((char2["bbox"][1] - char1["bbox"][3]) / img_size[1])
elif full_text["rotation"] == 270:
space_dists.append((char1["bbox"][0] - char2["bbox"][2]) / img_size[0])
else:
space_dists.append((char1["bbox"][1] - char2["bbox"][3]) / img_size[1])
cell_gap_thresh = np.percentile(space_dists, 80) if len(space_dists) > min_chars else default_thresh
return cell_gap_thresh


def is_same_span(char, curr_box, img_size, space_thresh, rotation):
def normalized_diff(a, b, dimension, mult=1, use_abs=True):
func = abs if use_abs else lambda x: x
return func(a - b) / img_size[dimension] < space_thresh * mult

bbox = char["bbox"]
if rotation == 90:
return all([
normalized_diff(bbox[0], curr_box[0], 0, use_abs=False),
normalized_diff(bbox[1], curr_box[3], 1),
normalized_diff(bbox[0], curr_box[0], 0, mult=5)
])
elif rotation == 180:
return all([
normalized_diff(bbox[2], curr_box[0], 0, use_abs=False),
normalized_diff(bbox[1], curr_box[1], 1),
normalized_diff(bbox[2], curr_box[0], 1, mult=5)
])
elif rotation == 270:
return all([
normalized_diff(bbox[0], curr_box[0], 0, use_abs=False),
normalized_diff(bbox[3], curr_box[1], 1),
normalized_diff(bbox[0], curr_box[0], 1, mult=5)
])
else: # 0 or default case
return all([
normalized_diff(bbox[0], curr_box[2], 0, use_abs=False),
normalized_diff(bbox[1], curr_box[1], 1),
normalized_diff(bbox[0], curr_box[2], 1, mult=5)
])


def get_table_blocks(tables: list, full_text: dict, img_size: list, table_thresh=.8, space_thresh=.01):
# Returns coordinates relative to input table, not full image
table_texts = []
space_thresh = max(space_thresh, get_dynamic_gap_thresh(full_text, img_size, default_thresh=space_thresh))
for table in tables:
table_poly = PolygonBox(polygon=[
[table[0], table[1]],
@@ -51,14 +105,7 @@ def get_table_blocks(tables: list, full_text: dict, img_size: list, table_thresh
for char in span["chars"]:
same_span = False
if curr_span:
if rotation == 90:
same_span = (char["bbox"][0] - curr_box[0]) / img_size[0] < 0.01 and abs(char["bbox"][1] - curr_box[3]) / img_size[1] < 0.01
elif rotation == 180:
same_span = (char["bbox"][2] - curr_box[0]) / img_size[0] < 0.01 and (char["bbox"][1] - curr_box[1]) / img_size[1] < 0.01
elif rotation == 270:
same_span = (char["bbox"][0] - curr_box[0]) / img_size[0] < 0.01 and abs(char["bbox"][3] - curr_box[1]) / img_size[1] < 0.01
else:
same_span = (char["bbox"][0] - curr_box[2]) / img_size[0] < 0.01 and (char["bbox"][1] - curr_box[1]) / img_size[1] < 0.01
same_span = is_same_span(char, curr_box, img_size, space_thresh, rotation)

if curr_span is None:
curr_span = char["char"]
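The new `get_dynamic_gap_thresh` above estimates a per-document spacing threshold: it collects the gap between consecutive characters (normalized by the image dimension along the reading direction implied by the page rotation) and uses the 80th percentile when there are enough samples, otherwise the supplied default; `is_same_span` then compares each character against the current span with that threshold. A rough standalone sketch of just the thresholding step (names are illustrative, not the library API):

```python
import numpy as np

def dynamic_gap_thresh(gaps, default_thresh=0.01, min_samples=100):
    # gaps: inter-character distances already normalized by the image dimension.
    # Use the 80th percentile when there is enough data, else fall back to the
    # default, mirroring get_dynamic_gap_thresh above.
    return float(np.percentile(gaps, 80)) if len(gaps) > min_samples else default_thresh

print(dynamic_gap_thresh([0.002, 0.004]))              # too few samples -> 0.01
print(dynamic_gap_thresh([0.001] * 80 + [0.02] * 40))  # enough samples -> 0.02
```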
22 changes: 18 additions & 4 deletions surya/schema.py
@@ -180,13 +180,27 @@ class OrderResult(BaseModel):


class TableCell(Bbox):
row_id: int | None = None
col_id: int | None = None
text: str | None = None


class TableRow(Bbox):
row_id: int

@property
def label(self):
return f'Row {self.row_id}'


class TableCol(Bbox):
col_id: int

@property
def label(self):
return f'Column {self.col_id}'


class TableResult(BaseModel):
cells: List[TableCell]
rows: List[TableCell]
cols: List[TableCell]
rows: List[TableRow]
cols: List[TableCol]
image_bbox: List[float]
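With this change, row and column predictions get their own `TableRow`/`TableCol` types with a human-readable `label`, while `TableCell` no longer carries library-assigned `row_id`/`col_id`. A small usage example of the new types (bbox values are made up):

```python
from surya.schema import TableRow, TableCol

row = TableRow(bbox=[0, 0, 200, 20], row_id=0)
col = TableCol(bbox=[0, 0, 40, 120], col_id=2)
print(row.label, "/", col.label)  # Row 0 / Column 2
```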
83 changes: 5 additions & 78 deletions surya/tables.py
@@ -5,7 +5,7 @@
from PIL import Image

from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
from surya.schema import TableResult, TableCell, Bbox
from surya.schema import TableResult, TableCell, Bbox, TableCol, TableRow
from surya.settings import settings
from tqdm import tqdm
import numpy as np
@@ -40,19 +40,6 @@ def sort_bboxes(bboxes, tolerance=1):
return sorted_page_blocks


def is_rotated(rows, cols):
# Determine if the table is rotated by looking at row and column width / height ratios
# Rows should have a >1 ratio, cols <1
widths = sum([r.width for r in rows])
heights = sum([c.height for c in rows]) + 1
r_ratio = widths / heights

widths = sum([c.width for c in cols])
heights = sum([r.height for r in cols]) + 1
c_ratio = widths / heights

return r_ratio * 2 < c_ratio

def batch_table_recognition(images: List, table_cells: List[List[Dict]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[TableResult]:
assert all([isinstance(image, Image.Image) for image in images])
assert len(images) == len(table_cells)
@@ -166,87 +153,27 @@ def batch_table_recognition(images: List, table_cells: List[List[Dict]], model:
rows = []
cols = []
for row_idx, row in enumerate(bb_rows):
cell = TableCell(
rows.append(TableRow(
bbox=row,
row_id=row_idx
)
rows.append(cell)
))

for col_idx, col in enumerate(bb_cols):
cell = TableCell(
cols.append(TableCol(
bbox=col,
col_id=col_idx,
)
cols.append(cell)
))

# Assign cells to rows/columns
cells = []
for cell in input_cells:
max_intersection = 0
row_pred = None
for row_idx, row in enumerate(rows):
intersection_pct = Bbox(bbox=cell["bbox"]).intersection_pct(row)
if intersection_pct > max_intersection:
max_intersection = intersection_pct
row_pred = row_idx

max_intersection = 0
col_pred = None
for col_idx, col in enumerate(cols):
intersection_pct = Bbox(bbox=cell["bbox"]).intersection_pct(col)
if intersection_pct > max_intersection:
max_intersection = intersection_pct
col_pred = col_idx

cells.append(
TableCell(
bbox=cell["bbox"],
text=cell.get("text"),
row_id=row_pred,
col_id=col_pred
)
)

rotated = is_rotated(rows, cols)
for cell in cells:
if cell.row_id is None:
closest_row = None
closest_row_dist = None
for cell2 in cells:
if cell2.row_id is None:
continue
if rotated:
cell_y_center = cell.center[0]
cell2_y_center = cell2.center[0]
else:
cell_y_center = cell.center[1]
cell2_y_center = cell2.center[1]
y_dist = abs(cell_y_center - cell2_y_center)
if closest_row_dist is None or y_dist < closest_row_dist:
closest_row = cell2.row_id
closest_row_dist = y_dist
cell.row_id = closest_row

if cell.col_id is None:
closest_col = None
closest_col_dist = None
for cell2 in cells:
if cell2.col_id is None:
continue
if rotated:
cell_x_center = cell.center[1]
cell2_x_center = cell2.center[1]
else:
cell_x_center = cell.center[0]
cell2_x_center = cell2.center[0]

x_dist = abs(cell2_x_center - cell_x_center)
if closest_col_dist is None or x_dist < closest_col_dist:
closest_col = cell2.col_id
closest_col_dist = x_dist

cell.col_id = closest_col

result = TableResult(
cells=cells,
rows=rows,
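Since the per-cell row/column assignment (and the `is_rotated` fallback) was removed from `batch_table_recognition`, callers that still need cell ids can reproduce the max-intersection heuristic on their side. A hedged sketch of that downstream step, reusing the `Bbox.intersection_pct` call the removed block relied on — these helpers are illustrative, not part of the library after this commit:

```python
from surya.schema import Bbox

def best_match(cell_box, regions):
    # Index of the region the cell overlaps most, or None if it overlaps nothing.
    best_idx, best_pct = None, 0
    for idx, region in enumerate(regions):
        pct = cell_box.intersection_pct(region)
        if pct > best_pct:
            best_idx, best_pct = idx, pct
    return best_idx

def assign_cells(input_cells, rows, cols):
    # Reattach row/col ids to detected cells on the caller side, mirroring the
    # max-intersection logic this commit moved out of surya/tables.py.
    out = []
    for cell in input_cells:
        box = Bbox(bbox=cell["bbox"])
        out.append({
            "bbox": cell["bbox"],
            "text": cell.get("text"),
            "row_id": best_match(box, rows),
            "col_id": best_match(box, cols),
        })
    return out
```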
5 changes: 0 additions & 5 deletions table_recognition.py
@@ -121,11 +121,6 @@ def main():
table_predictions[orig_name].append(out_pred)

if args.images:
boxes = [l.bbox for l in pred.cells]
labels = [f"{l.row_id}/{l.col_id}" for l in pred.cells]
bbox_image = draw_bboxes_on_image(boxes, copy.deepcopy(table_img), labels=labels, label_font_size=20)
bbox_image.save(os.path.join(result_path, f"{name}_page{pnum + 1}_table{table_idx}_cells.png"))

rows = [l.bbox for l in pred.rows]
cols = [l.bbox for l in pred.cols]
row_labels = [f"Row {l.row_id}" for l in pred.rows]
