Skip to content

Commit

Permalink
Add table parsing script
Browse files (browse the repository at this point in the history)
  • Loading branch information
VikParuchuri committed Oct 3, 2024
1 parent b3fe5ae commit 3e2b86c
Show file tree
Hide file tree
Showing 12 changed files with 223 additions and 136 deletions.
4 changes: 2 additions & 2 deletions detect_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def main():
det_processor = load_processor()

if os.path.isdir(args.input_path):
images, names = load_from_folder(args.input_path, args.max)
images, names, _ = load_from_folder(args.input_path, args.max)
folder_name = os.path.basename(args.input_path)
else:
images, names = load_from_file(args.input_path, args.max)
images, names, _ = load_from_file(args.input_path, args.max)
folder_name = os.path.basename(args.input_path).split(".")[0]

line_predictions = batch_text_detection(images, det_model, det_processor)
Expand Down
4 changes: 2 additions & 2 deletions detect_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ def main():
processor = load_processor(checkpoint=checkpoint)

if os.path.isdir(args.input_path):
images, names = load_from_folder(args.input_path, args.max)
images, names, _ = load_from_folder(args.input_path, args.max)
folder_name = os.path.basename(args.input_path)
else:
images, names = load_from_file(args.input_path, args.max)
images, names, _ = load_from_file(args.input_path, args.max)
folder_name = os.path.basename(args.input_path).split(".")[0]

start = time.time()
Expand Down
7 changes: 3 additions & 4 deletions ocr_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,10 @@ def order_detection(img) -> (Image.Image, OrderResult):

def table_recognition(img, filepath, page_idx: int, use_pdf_boxes: bool) -> (Image.Image, List[TableResult]):
_, layout_pred = layout_detection(img)
layout_tables = [l for l in layout_pred.bboxes if l.label == "Table"]
layout_tables_bboxes = [l.bbox for l in layout_tables]
layout_tables = [l.bbox for l in layout_pred.bboxes if l.label == "Table"]

table_imgs = []
for table_bbox in layout_tables_bboxes:
for table_bbox in layout_tables:
table_imgs.append(img.crop(table_bbox))
if use_pdf_boxes:
page_text = get_page_text_lines(filepath, page_idx, img.size)
Expand All @@ -93,7 +92,7 @@ def table_recognition(img, filepath, page_idx: int, use_pdf_boxes: bool) -> (Ima
table_bboxes = [[tb.bbox for tb in table_box.bboxes] for table_box in table_boxes]
table_preds = batch_table_recognition(table_imgs, table_bboxes, table_model, table_processor)
table_img = img.copy()
for results, table_bbox in zip(table_preds, layout_tables_bboxes):
for results, table_bbox in zip(table_preds, layout_tables):
adjusted_bboxes = []
labels = []
for item in results.cells:
Expand Down
4 changes: 2 additions & 2 deletions ocr_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ def main():
args = parser.parse_args()

if os.path.isdir(args.input_path):
images, names = load_from_folder(args.input_path, args.max, args.start_page)
images, names, _ = load_from_folder(args.input_path, args.max, args.start_page)
folder_name = os.path.basename(args.input_path)
else:
images, names = load_from_file(args.input_path, args.max, args.start_page)
images, names, _ = load_from_file(args.input_path, args.max, args.start_page)
folder_name = os.path.basename(args.input_path).split(".")[0]

if args.lang_file:
Expand Down
4 changes: 2 additions & 2 deletions reading_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def main():
det_processor = load_det_processor()

if os.path.isdir(args.input_path):
images, names = load_from_folder(args.input_path, args.max)
images, names, _ = load_from_folder(args.input_path, args.max)
folder_name = os.path.basename(args.input_path)
else:
images, names = load_from_file(args.input_path, args.max)
images, names, _ = load_from_file(args.input_path, args.max)
folder_name = os.path.basename(args.input_path).split(".")[0]

line_predictions = batch_text_detection(images, det_model, det_processor)
Expand Down
19 changes: 14 additions & 5 deletions surya/input/load.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import PIL

from surya.input.pdflines import get_page_text_lines
from surya.input.processing import open_pdf, get_page_images
import os
import filetype
Expand All @@ -26,15 +27,20 @@ def load_pdf(pdf_path, max_pages=None, start_page=None):

page_indices = list(range(start_page, last_page))
images = get_page_images(doc, page_indices)
text_lines = get_page_text_lines(
pdf_path,
page_indices,
[i.size for i in images]
)
doc.close()
names = [get_name_from_path(pdf_path) for _ in page_indices]
return images, names
return images, names, text_lines


def load_image(image_path):
image = Image.open(image_path).convert("RGB")
name = get_name_from_path(image_path)
return [image], [name]
return [image], [name], [None]


def load_from_file(input_path, max_pages=None, start_page=None):
Expand All @@ -51,21 +57,24 @@ def load_from_folder(folder_path, max_pages=None, start_page=None):

images = []
names = []
text_lines = []
for path in image_paths:
extension = filetype.guess(path)
if extension and extension.extension == "pdf":
image, name = load_pdf(path, max_pages, start_page)
image, name, text_line = load_pdf(path, max_pages, start_page)
images.extend(image)
names.extend(name)
text_lines.extend(text_line)
else:
try:
image, name = load_image(path)
image, name, text_line = load_image(path)
images.extend(image)
names.extend(name)
text_lines.extend(text_line)
except PIL.UnidentifiedImageError:
print(f"Could not load image {path}")
continue
return images, names
return images, names, text_lines


def load_lang_file(lang_path, names):
Expand Down
50 changes: 29 additions & 21 deletions surya/input/pdflines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,34 @@
from surya.schema import PolygonBox


def get_page_text_lines(filepath, page_idx, out_size):
full_text = dictionary_output(filepath, sort=False, page_range=[page_idx], keep_chars=True)[0]
text_bbox = full_text["bbox"]
text_w_scale = out_size[0] / text_bbox[2]
text_h_scale = out_size[1] / text_bbox[3]
for block in full_text["blocks"]:
for line in block["lines"]:
line["bbox"] = [line["bbox"][0] * text_w_scale, line["bbox"][1] * text_h_scale,
line["bbox"][2] * text_w_scale, line["bbox"][3] * text_h_scale]
for span in line["spans"]:
for char in span["chars"]:
char["bbox"] = [char["bbox"][0] * text_w_scale, char["bbox"][1] * text_h_scale,
char["bbox"][2] * text_w_scale, char["bbox"][3] * text_h_scale]
return full_text
def get_page_text_lines(filepath: str, page_idxs: list, out_sizes: list):
assert len(page_idxs) == len(out_sizes)
pages_text = dictionary_output(filepath, sort=False, page_range=page_idxs, keep_chars=True)
for full_text, out_size in zip(pages_text, out_sizes):
text_bbox = full_text["bbox"]
text_w_scale = out_size[0] / text_bbox[2]
text_h_scale = out_size[1] / text_bbox[3]
for block in full_text["blocks"]:
for line in block["lines"]:
line["bbox"] = [line["bbox"][0] * text_w_scale, line["bbox"][1] * text_h_scale,
line["bbox"][2] * text_w_scale, line["bbox"][3] * text_h_scale]
for span in line["spans"]:
for char in span["chars"]:
char["bbox"] = [char["bbox"][0] * text_w_scale, char["bbox"][1] * text_h_scale,
char["bbox"][2] * text_w_scale, char["bbox"][3] * text_h_scale]
return pages_text


def get_table_blocks(tables, full_text, img_size):
def get_table_blocks(tables: list, full_text: list, img_size: list, table_thresh=.8):
# Returns coordinates relative to input table, not full image
table_texts = []
for table in tables:
table_poly = PolygonBox(polygon=[
[table[0], table[1]],
[table[2], table[1]],
[table[2], table[3]],
[table[0], table[3]]
])
table_text = []
for block in full_text["blocks"]:
for line in block["lines"]:
Expand All @@ -33,7 +41,7 @@ def get_table_blocks(tables, full_text, img_size):
[line["bbox"][2], line["bbox"][3]],
[line["bbox"][0], line["bbox"][3]]
])
if line_poly.intersection_pct(table) < 0.8:
if line_poly.intersection_pct(table_poly) < table_thresh:
continue
curr_span = None
curr_box = None
Expand All @@ -42,7 +50,7 @@ def get_table_blocks(tables, full_text, img_size):
if curr_span is None:
curr_span = char["char"]
curr_box = char["bbox"]
elif (char["bbox"][0] - curr_box[2]) / img_size[0] < 0.01:
elif (char["bbox"][0] - curr_box[2]) / img_size[0] < 0.01 and (char["bbox"][1] - curr_box[1]) / img_size[1] < 0.01:
curr_span += char["char"]
curr_box = [min(curr_box[0], char["bbox"][0]), min(curr_box[1], char["bbox"][1]),
max(curr_box[2], char["bbox"][2]), max(curr_box[3], char["bbox"][3])]
Expand All @@ -55,10 +63,10 @@ def get_table_blocks(tables, full_text, img_size):
# Adjust to be relative to input table
for item in table_text:
item["bbox"] = [
item["bbox"][0] - table.bbox[0],
item["bbox"][1] - table.bbox[1],
item["bbox"][2] - table.bbox[0],
item["bbox"][3] - table.bbox[1]
item["bbox"][0] - table[0],
item["bbox"][1] - table[1],
item["bbox"][2] - table[0],
item["bbox"][3] - table[1]
]
table_text = sort_text_lines(table_text)
table_texts.append(table_text)
Expand Down
10 changes: 8 additions & 2 deletions surya/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
logits = np.stack(heatmaps, axis=0)
vertical_line_bboxes = [line for line in detection_result.vertical_lines]
vertical_line_bboxes = detection_result.vertical_lines
line_bboxes = detection_result.bboxes

# Scale back to processor size
Expand All @@ -38,6 +38,8 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
detected_boxes = []
for heatmap_idx in range(1, len(id2label)): # Skip the blank class
heatmap = logits[heatmap_idx]
if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD:
continue
bboxes = get_detected_boxes(heatmap)
bboxes = [bbox for bbox in bboxes if bbox.area > 25]
for bb in bboxes:
Expand Down Expand Up @@ -150,10 +152,14 @@ def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignm
heatmap = heatmaps[i]
assert heatmap.shape == segment_assignment.shape
heatmap[segment_assignment != i] = 0 # zero out where another segment is

# Skip processing empty labels
if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD:
continue

bbox = get_and_clean_boxes(heatmap, list(reversed(heatmap.shape)), orig_size)
for bb in bbox:
bboxes.append(LayoutBox(polygon=bb.polygon, label=id2label[i]))
heatmaps.append(heatmap)

bboxes = keep_largest_boxes(bboxes)
return bboxes
Expand Down
6 changes: 4 additions & 2 deletions surya/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,12 @@ class OrderResult(BaseModel):


class TableCell(Bbox):
row_id: int
col_id: int
row_id: int | None = None
col_id: int | None = None


class TableResult(BaseModel):
cells: List[TableCell]
rows: List[TableCell]
cols: List[TableCell]
image_bbox: List[float]
2 changes: 1 addition & 1 deletion surya/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench"

# Table Rec
TABLE_REC_MODEL_CHECKPOINT: str = "vikp/table_rec_ar2"
TABLE_REC_MODEL_CHECKPOINT: str = "vikp/table_rec_ar3"
TABLE_REC_IMAGE_SIZE: Dict = {"height": 640, "width": 640}
TABLE_REC_MAX_BOXES: int = 512
TABLE_REC_MAX_ROWS: int = 384
Expand Down
Loading

0 comments on commit 3e2b86c

Please sign in to comment.