From 4fa6ff60d6dbd483e63e883c114f02d19978a1ba Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Fri, 11 Oct 2024 17:06:54 -0400 Subject: [PATCH] Refactor to move cell assignment out of library --- CLA.md | 4 +- README.md | 2 - ocr_app.py | 11 ++++-- pyproject.toml | 2 +- surya/input/pdflines.py | 65 +++++++++++++++++++++++++++----- surya/schema.py | 22 +++++++++-- surya/tables.py | 83 +++-------------------------------------- table_recognition.py | 5 --- 8 files changed, 90 insertions(+), 104 deletions(-) diff --git a/CLA.md b/CLA.md index 93a0174..cbea049 100644 --- a/CLA.md +++ b/CLA.md @@ -1,6 +1,6 @@ Surya Contributor Agreement -This Surya Contributor Agreement ("SCA") applies to any contribution that you make to any product or project managed by us (the "project"), and sets out the intellectual property rights you grant to us in the contributed materials. The term "us" shall mean Vikas Paruchuri. The term "you" shall mean the person or entity identified below. +This Surya Contributor Agreement ("SCA") applies to any contribution that you make to any product or project managed by us (the "project"), and sets out the intellectual property rights you grant to us in the contributed materials. The term "us" shall mean Endless Labs, Inc. The term "you" shall mean the person or entity identified below. If you agree to be bound by these terms, sign by writing "I have read the CLA document and I hereby sign the CLA" in response to the CLA bot Github comment. Read this agreement carefully before signing. These terms and conditions constitute a binding legal agreement. @@ -20,5 +20,5 @@ If you or your affiliates institute patent litigation against any entity (includ - each contribution that you submit is and shall be an original work of authorship and you can legally grant the rights set out in this SCA; - to the best of your knowledge, each contribution will not violate any third party's copyrights, trademarks, patents, or other intellectual property rights; and - each contribution shall be in compliance with U.S. export control laws and other applicable export and import laws. -You agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Vikas Paruchuri may publicly disclose your participation in the project, including the fact that you have signed the SCA. +You agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Endless Labs, Inc. may publicly disclose your participation in the project, including the fact that you have signed the SCA. 6. This SCA is governed by the laws of the State of California and applicable U.S. Federal law. Any choice of law rules will not apply. \ No newline at end of file diff --git a/README.md b/README.md index defbcba..c9de236 100644 --- a/README.md +++ b/README.md @@ -298,8 +298,6 @@ The `results.json` file will contain a json dictionary where the keys are the in - `cells` - detected table cells - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. - - `row_id` - the id of the row this cell belongs to. - - `col_id` - the id of the column this cell belongs to. - `text` - if text could be pulled out of the pdf, the text of this cell. - `rows` - detected table rows - `bbox` - the bounding box of the table row diff --git a/ocr_app.py b/ocr_app.py index cb6791d..71aa64d 100644 --- a/ocr_app.py +++ b/ocr_app.py @@ -112,16 +112,21 @@ def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes: for results, table_bbox in zip(table_preds, layout_tables): adjusted_bboxes = [] labels = [] + colors = [] - for item in results.cells: + for item in results.rows + results.cols: adjusted_bboxes.append([ (item.bbox[0] + table_bbox[0]), (item.bbox[1] + table_bbox[1]), (item.bbox[2] + table_bbox[0]), (item.bbox[3] + table_bbox[1]) ]) - labels.append(f"{item.row_id} / {item.col_id}") - table_img = draw_bboxes_on_image(adjusted_bboxes, highres_img, labels=labels, label_font_size=18) + labels.append(item.label) + if hasattr(item, "row_id"): + colors.append("blue") + else: + colors.append("red") + table_img = draw_bboxes_on_image(adjusted_bboxes, highres_img, labels=labels, label_font_size=18, color=colors) return table_img, table_preds diff --git a/pyproject.toml b/pyproject.toml index dc5105c..ab76370 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "surya-ocr" -version = "0.6.1" +version = "0.6.2" description = "OCR, layout, reading order, and table recognition in 90+ languages" authors = ["Vik Paruchuri "] readme = "README.md" diff --git a/surya/input/pdflines.py b/surya/input/pdflines.py index 8f36aef..864bc7f 100644 --- a/surya/input/pdflines.py +++ b/surya/input/pdflines.py @@ -2,6 +2,7 @@ from surya.postprocessing.text import sort_text_lines from surya.schema import PolygonBox +import numpy as np def get_page_text_lines(filepath: str, page_idxs: list, out_sizes: list) -> list: @@ -23,9 +24,62 @@ def get_page_text_lines(filepath: str, page_idxs: list, out_sizes: list) -> list return pages_text -def get_table_blocks(tables: list, full_text: dict, img_size: list, table_thresh=.8): +def get_dynamic_gap_thresh(full_text: dict, img_size: list, default_thresh=.01, min_chars=100): + space_dists = [] + for block in full_text["blocks"]: + for line in block["lines"]: + for span in line["spans"]: + for i in range(1, len(span["chars"])): + char1 = span["chars"][i - 1] + char2 = span["chars"][i] + if full_text["rotation"] == 90: + space_dists.append((char2["bbox"][0] - char1["bbox"][2]) / img_size[0]) + elif full_text["rotation"] == 180: + space_dists.append((char2["bbox"][1] - char1["bbox"][3]) / img_size[1]) + elif full_text["rotation"] == 270: + space_dists.append((char1["bbox"][0] - char2["bbox"][2]) / img_size[0]) + else: + space_dists.append((char1["bbox"][1] - char2["bbox"][3]) / img_size[1]) + cell_gap_thresh = np.percentile(space_dists, 80) if len(space_dists) > min_chars else default_thresh + return cell_gap_thresh + + +def is_same_span(char, curr_box, img_size, space_thresh, rotation): + def normalized_diff(a, b, dimension, mult=1, use_abs=True): + func = abs if use_abs else lambda x: x + return func(a - b) / img_size[dimension] < space_thresh * mult + + bbox = char["bbox"] + if rotation == 90: + return all([ + normalized_diff(bbox[0], curr_box[0], 0, use_abs=False), + normalized_diff(bbox[1], curr_box[3], 1), + normalized_diff(bbox[0], curr_box[0], 0, mult=5) + ]) + elif rotation == 180: + return all([ + normalized_diff(bbox[2], curr_box[0], 0, use_abs=False), + normalized_diff(bbox[1], curr_box[1], 1), + normalized_diff(bbox[2], curr_box[0], 1, mult=5) + ]) + elif rotation == 270: + return all([ + normalized_diff(bbox[0], curr_box[0], 0, use_abs=False), + normalized_diff(bbox[3], curr_box[1], 1), + normalized_diff(bbox[0], curr_box[0], 1, mult=5) + ]) + else: # 0 or default case + return all([ + normalized_diff(bbox[0], curr_box[2], 0, use_abs=False), + normalized_diff(bbox[1], curr_box[1], 1), + normalized_diff(bbox[0], curr_box[2], 1, mult=5) + ]) + + +def get_table_blocks(tables: list, full_text: dict, img_size: list, table_thresh=.8, space_thresh=.01): # Returns coordinates relative to input table, not full image table_texts = [] + space_thresh = max(space_thresh, get_dynamic_gap_thresh(full_text, img_size, default_thresh=space_thresh)) for table in tables: table_poly = PolygonBox(polygon=[ [table[0], table[1]], @@ -51,14 +105,7 @@ def get_table_blocks(tables: list, full_text: dict, img_size: list, table_thresh for char in span["chars"]: same_span = False if curr_span: - if rotation == 90: - same_span = (char["bbox"][0] - curr_box[0]) / img_size[0] < 0.01 and abs(char["bbox"][1] - curr_box[3]) / img_size[1] < 0.01 - elif rotation == 180: - same_span = (char["bbox"][2] - curr_box[0]) / img_size[0] < 0.01 and (char["bbox"][1] - curr_box[1]) / img_size[1] < 0.01 - elif rotation == 270: - same_span = (char["bbox"][0] - curr_box[0]) / img_size[0] < 0.01 and abs(char["bbox"][3] - curr_box[1]) / img_size[1] < 0.01 - else: - same_span = (char["bbox"][0] - curr_box[2]) / img_size[0] < 0.01 and (char["bbox"][1] - curr_box[1]) / img_size[1] < 0.01 + same_span = is_same_span(char, curr_box, img_size, space_thresh, rotation) if curr_span is None: curr_span = char["char"] diff --git a/surya/schema.py b/surya/schema.py index 880a77e..1f41438 100644 --- a/surya/schema.py +++ b/surya/schema.py @@ -180,13 +180,27 @@ class OrderResult(BaseModel): class TableCell(Bbox): - row_id: int | None = None - col_id: int | None = None text: str | None = None +class TableRow(Bbox): + row_id: int + + @property + def label(self): + return f'Row {self.row_id}' + + +class TableCol(Bbox): + col_id: int + + @property + def label(self): + return f'Column {self.col_id}' + + class TableResult(BaseModel): cells: List[TableCell] - rows: List[TableCell] - cols: List[TableCell] + rows: List[TableRow] + cols: List[TableCol] image_bbox: List[float] diff --git a/surya/tables.py b/surya/tables.py index 9e53344..01c7ead 100644 --- a/surya/tables.py +++ b/surya/tables.py @@ -5,7 +5,7 @@ from PIL import Image from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel -from surya.schema import TableResult, TableCell, Bbox +from surya.schema import TableResult, TableCell, Bbox, TableCol, TableRow from surya.settings import settings from tqdm import tqdm import numpy as np @@ -40,19 +40,6 @@ def sort_bboxes(bboxes, tolerance=1): return sorted_page_blocks -def is_rotated(rows, cols): - # Determine if the table is rotated by looking at row and column width / height ratios - # Rows should have a >1 ratio, cols <1 - widths = sum([r.width for r in rows]) - heights = sum([c.height for c in rows]) + 1 - r_ratio = widths / heights - - widths = sum([c.width for c in cols]) - heights = sum([r.height for r in cols]) + 1 - c_ratio = widths / heights - - return r_ratio * 2 < c_ratio - def batch_table_recognition(images: List, table_cells: List[List[Dict]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[TableResult]: assert all([isinstance(image, Image.Image) for image in images]) assert len(images) == len(table_cells) @@ -166,87 +153,27 @@ def batch_table_recognition(images: List, table_cells: List[List[Dict]], model: rows = [] cols = [] for row_idx, row in enumerate(bb_rows): - cell = TableCell( + rows.append(TableRow( bbox=row, row_id=row_idx - ) - rows.append(cell) + )) for col_idx, col in enumerate(bb_cols): - cell = TableCell( + cols.append(TableCol( bbox=col, col_id=col_idx, - ) - cols.append(cell) + )) # Assign cells to rows/columns cells = [] for cell in input_cells: - max_intersection = 0 - row_pred = None - for row_idx, row in enumerate(rows): - intersection_pct = Bbox(bbox=cell["bbox"]).intersection_pct(row) - if intersection_pct > max_intersection: - max_intersection = intersection_pct - row_pred = row_idx - - max_intersection = 0 - col_pred = None - for col_idx, col in enumerate(cols): - intersection_pct = Bbox(bbox=cell["bbox"]).intersection_pct(col) - if intersection_pct > max_intersection: - max_intersection = intersection_pct - col_pred = col_idx - cells.append( TableCell( bbox=cell["bbox"], text=cell.get("text"), - row_id=row_pred, - col_id=col_pred ) ) - rotated = is_rotated(rows, cols) - for cell in cells: - if cell.row_id is None: - closest_row = None - closest_row_dist = None - for cell2 in cells: - if cell2.row_id is None: - continue - if rotated: - cell_y_center = cell.center[0] - cell2_y_center = cell2.center[0] - else: - cell_y_center = cell.center[1] - cell2_y_center = cell2.center[1] - y_dist = abs(cell_y_center - cell2_y_center) - if closest_row_dist is None or y_dist < closest_row_dist: - closest_row = cell2.row_id - closest_row_dist = y_dist - cell.row_id = closest_row - - if cell.col_id is None: - closest_col = None - closest_col_dist = None - for cell2 in cells: - if cell2.col_id is None: - continue - if rotated: - cell_x_center = cell.center[1] - cell2_x_center = cell2.center[1] - else: - cell_x_center = cell.center[0] - cell2_x_center = cell2.center[0] - - x_dist = abs(cell2_x_center - cell_x_center) - if closest_col_dist is None or x_dist < closest_col_dist: - closest_col = cell2.col_id - closest_col_dist = x_dist - - cell.col_id = closest_col - result = TableResult( cells=cells, rows=rows, diff --git a/table_recognition.py b/table_recognition.py index 4e03335..013500e 100644 --- a/table_recognition.py +++ b/table_recognition.py @@ -121,11 +121,6 @@ def main(): table_predictions[orig_name].append(out_pred) if args.images: - boxes = [l.bbox for l in pred.cells] - labels = [f"{l.row_id}/{l.col_id}" for l in pred.cells] - bbox_image = draw_bboxes_on_image(boxes, copy.deepcopy(table_img), labels=labels, label_font_size=20) - bbox_image.save(os.path.join(result_path, f"{name}_page{pnum + 1}_table{table_idx}_cells.png")) - rows = [l.bbox for l in pred.rows] cols = [l.bbox for l in pred.cols] row_labels = [f"Row {l.row_id}" for l in pred.rows]