From 3e2b86c3ccddfe6905daa9b702f397c064e1523f Mon Sep 17 00:00:00 2001
From: Vik Paruchuri
Date: Thu, 3 Oct 2024 14:39:55 -0400
Subject: [PATCH] Add table parsing script

---
 detect_layout.py        |   4 +-
 detect_text.py          |   4 +-
 ocr_app.py              |   7 +-
 ocr_text.py             |   4 +-
 reading_order.py        |   4 +-
 surya/input/load.py     |  19 ++++--
 surya/input/pdflines.py |  50 ++++++------
 surya/layout.py         |  10 ++-
 surya/schema.py         |   6 +-
 surya/settings.py       |   2 +-
 surya/tables.py         | 144 ++++++++++++++++++++++------------
 table_recognition.py    | 105 +++++++++++++++++++++--------
 12 files changed, 223 insertions(+), 136 deletions(-)

diff --git a/detect_layout.py b/detect_layout.py
index 8e791b7..3a54f81 100644
--- a/detect_layout.py
+++ b/detect_layout.py
@@ -27,10 +27,10 @@ def main():
     det_processor = load_processor()
 
     if os.path.isdir(args.input_path):
-        images, names = load_from_folder(args.input_path, args.max)
+        images, names, _ = load_from_folder(args.input_path, args.max)
         folder_name = os.path.basename(args.input_path)
     else:
-        images, names = load_from_file(args.input_path, args.max)
+        images, names, _ = load_from_file(args.input_path, args.max)
         folder_name = os.path.basename(args.input_path).split(".")[0]
 
     line_predictions = batch_text_detection(images, det_model, det_processor)
diff --git a/detect_text.py b/detect_text.py
index e2ecc4d..e7b0e5a 100644
--- a/detect_text.py
+++ b/detect_text.py
@@ -28,10 +28,10 @@ def main():
     processor = load_processor(checkpoint=checkpoint)
 
     if os.path.isdir(args.input_path):
-        images, names = load_from_folder(args.input_path, args.max)
+        images, names, _ = load_from_folder(args.input_path, args.max)
         folder_name = os.path.basename(args.input_path)
     else:
-        images, names = load_from_file(args.input_path, args.max)
+        images, names, _ = load_from_file(args.input_path, args.max)
         folder_name = os.path.basename(args.input_path).split(".")[0]
 
     start = time.time()
diff --git a/ocr_app.py b/ocr_app.py
index c8136cc..df7b366 100644
--- a/ocr_app.py
+++ b/ocr_app.py
@@ -78,11 +78,10 @@ def order_detection(img) -> (Image.Image, OrderResult):
 def table_recognition(img, filepath, page_idx: int, use_pdf_boxes: bool) -> (Image.Image, List[TableResult]):
     _, layout_pred = layout_detection(img)
-    layout_tables = [l for l in layout_pred.bboxes if l.label == "Table"]
-    layout_tables_bboxes = [l.bbox for l in layout_tables]
+    layout_tables = [l.bbox for l in layout_pred.bboxes if l.label == "Table"]
 
     table_imgs = []
-    for table_bbox in layout_tables_bboxes:
+    for table_bbox in layout_tables:
         table_imgs.append(img.crop(table_bbox))
 
     if use_pdf_boxes:
         page_text = get_page_text_lines(filepath, page_idx, img.size)
@@ -93,7 +92,7 @@ def table_recognition(img, filepath, page_idx: int, use_pdf_boxes: bool) -> (Ima
     table_bboxes = [[tb.bbox for tb in table_box.bboxes] for table_box in table_boxes]
     table_preds = batch_table_recognition(table_imgs, table_bboxes, table_model, table_processor)
     table_img = img.copy()
-    for results, table_bbox in zip(table_preds, layout_tables_bboxes):
+    for results, table_bbox in zip(table_preds, layout_tables):
         adjusted_bboxes = []
         labels = []
         for item in results.cells:
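Note on the ocr_app.py hunk above: get_page_text_lines is reworked in surya/input/pdflines.py below to take a list of page indices and a list of output sizes, but the unchanged context line here still passes a scalar page_idx and a single img.size. If that signature change lands as-is, this call site presumably needs a follow-up along these lines (hypothetical, not part of this patch):

    # Wrap the single page in lists and unwrap the single result.
    page_text = get_page_text_lines(filepath, [page_idx], [img.size])[0]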
diff --git a/ocr_text.py b/ocr_text.py
index 5b5bd65..b22ec95 100644
--- a/ocr_text.py
+++ b/ocr_text.py
@@ -30,10 +30,10 @@ def main():
     args = parser.parse_args()
 
     if os.path.isdir(args.input_path):
-        images, names = load_from_folder(args.input_path, args.max, args.start_page)
+        images, names, _ = load_from_folder(args.input_path, args.max, args.start_page)
         folder_name = os.path.basename(args.input_path)
     else:
-        images, names = load_from_file(args.input_path, args.max, args.start_page)
+        images, names, _ = load_from_file(args.input_path, args.max, args.start_page)
         folder_name = os.path.basename(args.input_path).split(".")[0]
 
     if args.lang_file:
diff --git a/reading_order.py b/reading_order.py
index cc30ad2..4277a8a 100644
--- a/reading_order.py
+++ b/reading_order.py
@@ -33,10 +33,10 @@ def main():
     det_processor = load_det_processor()
 
     if os.path.isdir(args.input_path):
-        images, names = load_from_folder(args.input_path, args.max)
+        images, names, _ = load_from_folder(args.input_path, args.max)
         folder_name = os.path.basename(args.input_path)
     else:
-        images, names = load_from_file(args.input_path, args.max)
+        images, names, _ = load_from_file(args.input_path, args.max)
         folder_name = os.path.basename(args.input_path).split(".")[0]
 
     line_predictions = batch_text_detection(images, det_model, det_processor)
diff --git a/surya/input/load.py b/surya/input/load.py
index aa8f1a1..304fde4 100644
--- a/surya/input/load.py
+++ b/surya/input/load.py
@@ -1,5 +1,6 @@
 import PIL
 
+from surya.input.pdflines import get_page_text_lines
 from surya.input.processing import open_pdf, get_page_images
 import os
 import filetype
@@ -26,15 +27,20 @@ def load_pdf(pdf_path, max_pages=None, start_page=None):
         page_indices = list(range(start_page, last_page))
 
     images = get_page_images(doc, page_indices)
+    text_lines = get_page_text_lines(
+        pdf_path,
+        page_indices,
+        [i.size for i in images]
+    )
     doc.close()
     names = [get_name_from_path(pdf_path) for _ in page_indices]
-    return images, names
+    return images, names, text_lines
 
 
 def load_image(image_path):
     image = Image.open(image_path).convert("RGB")
     name = get_name_from_path(image_path)
-    return [image], [name]
+    return [image], [name], [None]
 
 
 def load_from_file(input_path, max_pages=None, start_page=None):
@@ -51,21 +57,24 @@ def load_from_folder(folder_path, max_pages=None, start_page=None):
 
     images = []
     names = []
+    text_lines = []
     for path in image_paths:
         extension = filetype.guess(path)
         if extension and extension.extension == "pdf":
-            image, name = load_pdf(path, max_pages, start_page)
+            image, name, text_line = load_pdf(path, max_pages, start_page)
             images.extend(image)
             names.extend(name)
+            text_lines.extend(text_line)
         else:
             try:
-                image, name = load_image(path)
+                image, name, text_line = load_image(path)
                 images.extend(image)
                 names.extend(name)
+                text_lines.extend(text_line)
             except PIL.UnidentifiedImageError:
                 print(f"Could not load image {path}")
                 continue
 
-    return images, names
+    return images, names, text_lines
 
 
 def load_lang_file(lang_path, names):
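For context, a minimal sketch of the loader contract after this change, assuming a hypothetical example.pdf: all loaders now return a third parallel list, with per-page pdftext dictionaries for PDFs and None placeholders for plain images, so callers that do not need text can unpack with `_` as the scripts above do.

    from surya.input.load import load_from_file

    images, names, text_lines = load_from_file("example.pdf", max_pages=2)
    for name, lines in zip(names, text_lines):
        if lines is None:
            print(f"{name}: no embedded text; fall back to OCR detection")
        else:
            print(f"{name}: {len(lines['blocks'])} text blocks from the PDF")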
diff --git a/surya/input/pdflines.py b/surya/input/pdflines.py
index d59a497..dbed72b 100644
--- a/surya/input/pdflines.py
+++ b/surya/input/pdflines.py
@@ -4,26 +4,34 @@
 from surya.schema import PolygonBox
 
 
-def get_page_text_lines(filepath, page_idx, out_size):
-    full_text = dictionary_output(filepath, sort=False, page_range=[page_idx], keep_chars=True)[0]
-    text_bbox = full_text["bbox"]
-    text_w_scale = out_size[0] / text_bbox[2]
-    text_h_scale = out_size[1] / text_bbox[3]
-    for block in full_text["blocks"]:
-        for line in block["lines"]:
-            line["bbox"] = [line["bbox"][0] * text_w_scale, line["bbox"][1] * text_h_scale,
-                            line["bbox"][2] * text_w_scale, line["bbox"][3] * text_h_scale]
-            for span in line["spans"]:
-                for char in span["chars"]:
-                    char["bbox"] = [char["bbox"][0] * text_w_scale, char["bbox"][1] * text_h_scale,
-                                    char["bbox"][2] * text_w_scale, char["bbox"][3] * text_h_scale]
-    return full_text
+def get_page_text_lines(filepath: str, page_idxs: list, out_sizes: list):
+    assert len(page_idxs) == len(out_sizes)
+    pages_text = dictionary_output(filepath, sort=False, page_range=page_idxs, keep_chars=True)
+    for full_text, out_size in zip(pages_text, out_sizes):
+        text_bbox = full_text["bbox"]
+        text_w_scale = out_size[0] / text_bbox[2]
+        text_h_scale = out_size[1] / text_bbox[3]
+        for block in full_text["blocks"]:
+            for line in block["lines"]:
+                line["bbox"] = [line["bbox"][0] * text_w_scale, line["bbox"][1] * text_h_scale,
+                                line["bbox"][2] * text_w_scale, line["bbox"][3] * text_h_scale]
+                for span in line["spans"]:
+                    for char in span["chars"]:
+                        char["bbox"] = [char["bbox"][0] * text_w_scale, char["bbox"][1] * text_h_scale,
+                                        char["bbox"][2] * text_w_scale, char["bbox"][3] * text_h_scale]
+    return pages_text
 
 
-def get_table_blocks(tables, full_text, img_size):
+def get_table_blocks(tables: list, full_text: list, img_size: list, table_thresh=.8):
     # Returns coordinates relative to input table, not full image
     table_texts = []
     for table in tables:
+        table_poly = PolygonBox(polygon=[
+            [table[0], table[1]],
+            [table[2], table[1]],
+            [table[2], table[3]],
+            [table[0], table[3]]
+        ])
         table_text = []
         for block in full_text["blocks"]:
             for line in block["lines"]:
@@ -33,7 +41,7 @@ def get_table_blocks(tables, full_text, img_size):
                     [line["bbox"][2], line["bbox"][3]],
                     [line["bbox"][0], line["bbox"][3]]
                 ])
-                if line_poly.intersection_pct(table) < 0.8:
+                if line_poly.intersection_pct(table_poly) < table_thresh:
                     continue
                 curr_span = None
                 curr_box = None
@@ -42,7 +50,7 @@ def get_table_blocks(tables, full_text, img_size):
                     if curr_span is None:
                         curr_span = char["char"]
                         curr_box = char["bbox"]
-                    elif (char["bbox"][0] - curr_box[2]) / img_size[0] < 0.01:
+                    elif (char["bbox"][0] - curr_box[2]) / img_size[0] < 0.01 and (char["bbox"][1] - curr_box[1]) / img_size[1] < 0.01:
                         curr_span += char["char"]
                         curr_box = [min(curr_box[0], char["bbox"][0]), min(curr_box[1], char["bbox"][1]),
                                     max(curr_box[2], char["bbox"][2]), max(curr_box[3], char["bbox"][3])]
@@ -55,10 +63,10 @@ def get_table_blocks(tables, full_text, img_size):
         # Adjust to be relative to input table
         for item in table_text:
             item["bbox"] = [
-                item["bbox"][0] - table.bbox[0],
-                item["bbox"][1] - table.bbox[1],
-                item["bbox"][2] - table.bbox[0],
-                item["bbox"][3] - table.bbox[1]
+                item["bbox"][0] - table[0],
+                item["bbox"][1] - table[1],
+                item["bbox"][2] - table[0],
+                item["bbox"][3] - table[1]
             ]
         table_text = sort_text_lines(table_text)
         table_texts.append(table_text)
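The scaling above maps pdftext output (PDF points) into the rendered image's pixel space. The arithmetic in isolation, with made-up numbers:

    pdf_bbox = [0, 0, 612, 792]          # page bbox in PDF points (US Letter)
    out_size = (1224, 1584)              # rendered image (width, height) in pixels
    w_scale = out_size[0] / pdf_bbox[2]  # 2.0
    h_scale = out_size[1] / pdf_bbox[3]  # 2.0

    line_bbox = [100, 50, 300, 70]       # x1, y1, x2, y2 in PDF points
    scaled = [line_bbox[0] * w_scale, line_bbox[1] * h_scale,
              line_bbox[2] * w_scale, line_bbox[3] * h_scale]
    assert scaled == [200.0, 100.0, 600.0, 140.0]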
diff --git a/surya/layout.py b/surya/layout.py
index 4702168..9f7dd4d 100644
--- a/surya/layout.py
+++ b/surya/layout.py
@@ -12,7 +12,7 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
     logits = np.stack(heatmaps, axis=0)
-    vertical_line_bboxes = [line for line in detection_result.vertical_lines]
+    vertical_line_bboxes = detection_result.vertical_lines
     line_bboxes = detection_result.bboxes
 
     # Scale back to processor size
@@ -38,6 +38,8 @@
     detected_boxes = []
     for heatmap_idx in range(1, len(id2label)):  # Skip the blank class
         heatmap = logits[heatmap_idx]
+        if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD:
+            continue
         bboxes = get_detected_boxes(heatmap)
         bboxes = [bbox for bbox in bboxes if bbox.area > 25]
         for bb in bboxes:
@@ -150,10 +152,14 @@ def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignm
         heatmap = heatmaps[i]
         assert heatmap.shape == segment_assignment.shape
         heatmap[segment_assignment != i] = 0  # zero out where another segment is
+
+        # Skip processing empty labels
+        if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD:
+            continue
+
         bbox = get_and_clean_boxes(heatmap, list(reversed(heatmap.shape)), orig_size)
         for bb in bbox:
             bboxes.append(LayoutBox(polygon=bb.polygon, label=id2label[i]))
-        heatmaps.append(heatmap)
 
     bboxes = keep_largest_boxes(bboxes)
     return bboxes
diff --git a/surya/schema.py b/surya/schema.py
index d6d4e7b..8b9204f 100644
--- a/surya/schema.py
+++ b/surya/schema.py
@@ -176,10 +176,12 @@ class OrderResult(BaseModel):
 
 
 class TableCell(Bbox):
-    row_id: int
-    col_id: int
+    row_id: int | None = None
+    col_id: int | None = None
 
 
 class TableResult(BaseModel):
     cells: List[TableCell]
+    rows: List[TableCell]
+    cols: List[TableCell]
     image_bbox: List[float]
diff --git a/surya/settings.py b/surya/settings.py
index 7f0063d..5a8f1ea 100644
--- a/surya/settings.py
+++ b/surya/settings.py
@@ -72,7 +72,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
     ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench"
 
     # Table Rec
-    TABLE_REC_MODEL_CHECKPOINT: str = "vikp/table_rec_ar2"
+    TABLE_REC_MODEL_CHECKPOINT: str = "vikp/table_rec_ar3"
     TABLE_REC_IMAGE_SIZE: Dict = {"height": 640, "width": 640}
     TABLE_REC_MAX_BOXES: int = 512
     TABLE_REC_MAX_ROWS: int = 384
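The two layout.py hunks add the same early exit: if a class heatmap never rises above the blank threshold, box extraction for that label is skipped entirely. A sketch of the idea (the 0.35 threshold is assumed here; the real value lives in surya.settings):

    import numpy as np

    DETECTOR_BLANK_THRESHOLD = 0.35  # assumed; see surya.settings

    heatmaps = {"Table": np.zeros((100, 100)), "Text": np.full((100, 100), 0.9)}
    for label, heatmap in heatmaps.items():
        if np.max(heatmap) < DETECTOR_BLANK_THRESHOLD:
            continue  # nothing detected for this label; skip the box pass
        print(f"extract boxes for {label}")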
diff --git a/surya/tables.py b/surya/tables.py
index f3a00be..ac6536e 100644
--- a/surya/tables.py
+++ b/surya/tables.py
@@ -62,8 +62,8 @@ def corners_to_cx_cy(pred):
 
 def snap_to_bboxes(rc_box, input_boxes, used_cells_row, used_cells_col, row=True, row_threshold=.2, col_threshold=.2):
     sel_bboxes = []
+    rc_corner_bbox = cx_cy_to_corners(rc_box)
     for cell_idx, cell in enumerate(input_boxes):
-        rc_corner_bbox = cx_cy_to_corners(rc_box)
         intersection_pct = Bbox(bbox=cell).intersection_pct(Bbox(bbox=rc_corner_bbox))
 
         if row:
@@ -97,15 +97,15 @@
 
 
-def batch_table_recognition(images: List, bboxes: List[List[List[float]]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[TableResult]:
+def batch_table_recognition(images: List, input_bboxes: List[List[List[float]]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[TableResult]:
     assert all([isinstance(image, Image.Image) for image in images])
-    assert len(images) == len(bboxes)
+    assert len(images) == len(input_bboxes)
     if batch_size is None:
         batch_size = get_batch_size()
 
     output_order = []
-    for i in tqdm(range(0, len(images), batch_size), desc="Finding reading order"):
-        batch_list_bboxes = deepcopy(bboxes[i:i+batch_size])
+    for i in tqdm(range(0, len(images), batch_size), desc="Recognizing tables"):
+        batch_list_bboxes = deepcopy(input_bboxes[i:i+batch_size])
         batch_list_bboxes = [sort_bboxes(page_bboxes) for page_bboxes in batch_list_bboxes]  # Sort bboxes before passing in
 
         batch_images = images[i:i+batch_size]
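The snap_to_bboxes hunk hoists cx_cy_to_corners out of the cell loop, since the converted box depends only on rc_box, not on the cell. A sketch of what that conversion does (the real helper is defined earlier in surya/tables.py; this version only assumes the (cx, cy, w, h, ...) layout used in this file):

    def cx_cy_to_corners(box):
        cx, cy, w, h = box[:4]
        return [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2]

    assert cx_cy_to_corners([50, 50, 20, 10]) == [40.0, 45.0, 60.0, 55.0]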
@@ -199,62 +199,14 @@ def batch_table_recognition(images: List, bboxes: List[List[List[float]]], model
             token_count += inference_token_count
             inference_token_count = batch_decoder_input.shape[1]
 
-        """
-        for j, (preds, bboxes, orig_size) in enumerate(zip(batch_predictions, batch_list_bboxes, orig_sizes)):
-            out_data = []
-            # They either match up, or there are too many bboxes passed in
-            img_w, img_h = orig_size
-            # cx, cy to corners
-            for i, pred in enumerate(preds):
-                scale_w = img_w / model.config.decoder.out_box_size
-                scale_h = img_h / model.config.decoder.out_box_size
-                class_ = int(pred[4] - SPECIAL_TOKENS)
-                pred = cx_cy_to_corners(pred)
-
-                preds[i] = [pred[0] * scale_w, pred[1] * scale_h, pred[2] * scale_w, pred[3] * scale_h, class_]
-
-            rows = [p[:4] for p in preds if p[4] == 0]
-            cols = [p[:4] for p in preds if p[4] == 1]
-
-            for cell in bboxes:
-                max_intersection = 0
-                row_pred = -1
-                for row_idx, row in enumerate(rows):
-                    intersection_pct = Bbox(bbox=cell).intersection_pct(Bbox(bbox=row))
-                    if intersection_pct > max_intersection:
-                        max_intersection = intersection_pct
-                        row_pred = row_idx
-
-                max_intersection = 0
-                col_pred = -1
-                for col_idx, col in enumerate(cols):
-                    intersection_pct = Bbox(bbox=cell).intersection_pct(Bbox(bbox=col))
-                    if intersection_pct > max_intersection:
-                        max_intersection = intersection_pct
-                        col_pred = col_idx
-
-                cell = TableCell(
-                    bbox=cell,
-                    col_id=col_pred,
-                    row_id=row_pred
-                )
-                out_data.append(cell)
-            result = TableResult(
-                cells=out_data,
-                image_bbox=[0, 0, img_w, img_h],
-            )
-
-            output_order.append(result)
-        """
         for j, (preds, bboxes, orig_size) in enumerate(zip(batch_predictions, batch_list_bboxes, orig_sizes)):
-            out_data = []
-            # They either match up, or there are too many bboxes passed in
             img_w, img_h = orig_size
+            width_scaler = img_w / model.config.decoder.out_box_size
+            height_scaler = img_h / model.config.decoder.out_box_size
+
 
             # cx, cy to corners
             for i, pred in enumerate(preds):
-                width_scaler = img_w / model.config.decoder.out_box_size
-                height_scaler = img_h / model.config.decoder.out_box_size
                 w = pred[2] / 2
                 h = pred[3] / 2
                 x1 = pred[0] - w
@@ -265,26 +217,88 @@ def batch_table_recognition(images: List, bboxes: List[List[List[float]]], model
             preds[i] = [x1 * width_scaler, y1 * height_scaler, x2 * width_scaler, y2 * height_scaler, class_]
 
-            rows = [p[:4] for p in preds if p[4] == 0]
-            cols = [p[:4] for p in preds if p[4] == 1]
-            for row_idx, row in enumerate(rows):
+
+            # Get rows and columns
+            bb_rows = [p[:4] for p in preds if p[4] == 0]
+            bb_cols = [p[:4] for p in preds if p[4] == 1]
+
+            rows = []
+            cols = []
+            for row_idx, row in enumerate(bb_rows):
                 cell = TableCell(
                     bbox=row,
-                    col_id=-1,
                     row_id=row_idx
                 )
-                out_data.append(cell)
+                rows.append(cell)
 
-            for col_idx, col in enumerate(cols):
+            for col_idx, col in enumerate(bb_cols):
                 cell = TableCell(
                     bbox=col,
                     col_id=col_idx,
-                    row_id=-1
                 )
-                out_data.append(cell)
+                cols.append(cell)
+
+            # Assign cells to rows/columns
+            cells = []
+            for cell in bboxes:
+                max_intersection = 0
+                row_pred = None
+                for row_idx, row in enumerate(rows):
+                    intersection_pct = Bbox(bbox=cell).intersection_pct(row)
+                    if intersection_pct > max_intersection:
+                        max_intersection = intersection_pct
+                        row_pred = row_idx
+
+                max_intersection = 0
+                col_pred = None
+                for col_idx, col in enumerate(cols):
+                    intersection_pct = Bbox(bbox=cell).intersection_pct(col)
+                    if intersection_pct > max_intersection:
+                        max_intersection = intersection_pct
+                        col_pred = col_idx
+
+                cells.append(
+                    TableCell(
+                        bbox=cell,
+                        row_id=row_pred,
+                        col_id=col_pred
+                    )
+                )
+
+            for cell in cells:
+                if cell.row_id is None:
+                    closest_row = None
+                    closest_row_dist = None
+                    for cell2 in cells:
+                        if cell2.row_id is None:
+                            continue
+                        cell_y_center = (cell.bbox[1] + cell.bbox[3]) / 2
+                        cell2_y_center = (cell2.bbox[1] + cell2.bbox[3]) / 2
+                        y_dist = abs(cell_y_center - cell2_y_center)
+                        if closest_row_dist is None or y_dist < closest_row_dist:
+                            closest_row = cell2.row_id
+                            closest_row_dist = y_dist
+                    cell.row_id = closest_row
+
+                if cell.col_id is None:
+                    closest_col = None
+                    closest_col_dist = None
+                    for cell2 in cells:
+                        if cell2.col_id is None:
+                            continue
+                        cell_x_center = (cell.bbox[0] + cell.bbox[2]) / 2
+                        cell2_x_center = (cell2.bbox[0] + cell2.bbox[2]) / 2
+                        x_dist = abs(cell2_x_center - cell_x_center)
+                        if closest_col_dist is None or x_dist < closest_col_dist:
+                            closest_col = cell2.col_id
+                            closest_col_dist = x_dist
+
+                    cell.col_id = closest_col
 
             result = TableResult(
-                cells=out_data,
+                cells=cells,
+                rows=rows,
+                cols=cols,
                 image_bbox=[0, 0, img_w, img_h],
             )
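The fallback pass above assigns a row and column to any cell that overlapped nothing: it borrows the row_id of the cell whose vertical center is closest (and likewise col_id by horizontal center). The row case in isolation, on plain dicts:

    def nearest_row_id(cell, cells):
        cy = (cell["bbox"][1] + cell["bbox"][3]) / 2
        best_id, best_dist = None, None
        for other in cells:
            if other is cell or other["row_id"] is None:
                continue
            ocy = (other["bbox"][1] + other["bbox"][3]) / 2
            if best_dist is None or abs(cy - ocy) < best_dist:
                best_id, best_dist = other["row_id"], abs(cy - ocy)
        return best_id

    cells = [{"bbox": [0, 0, 10, 10], "row_id": 0},
             {"bbox": [0, 12, 10, 22], "row_id": None}]
    assert nearest_row_id(cells[1], cells) == 0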
diff --git a/table_recognition.py b/table_recognition.py
index c83d2e3..1d7bfbc 100644
--- a/table_recognition.py
+++ b/table_recognition.py
@@ -1,3 +1,4 @@
+import pypdfium2 as pdfium  # Needs to be on top to avoid warning
 import os
 import argparse
 import copy
@@ -6,12 +7,13 @@
 
 from surya.detection import batch_text_detection
 from surya.input.load import load_from_folder, load_from_file
+from surya.input.pdflines import get_table_blocks
 from surya.layout import batch_layout_detection
 from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
-from surya.model.ordering.model import load_model
-from surya.model.ordering.processor import load_processor
-from surya.ordering import batch_ordering
-from surya.postprocessing.heatmap import draw_polys_on_image
+from surya.model.table_rec.model import load_model
+from surya.model.table_rec.processor import load_processor
+from surya.tables import batch_table_recognition
+from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
 from surya.settings import settings
@@ -21,6 +23,7 @@ def main():
     parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya"))
     parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
     parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False)
+    parser.add_argument("--detect_boxes", action="store_true", help="Detect table boxes.", default=False)
     args = parser.parse_args()
 
     model = load_model()
@@ -33,46 +36,93 @@ def main():
     det_processor = load_det_processor()
 
     if os.path.isdir(args.input_path):
-        images, names = load_from_folder(args.input_path, args.max)
+        images, names, text_lines = load_from_folder(args.input_path, args.max)
         folder_name = os.path.basename(args.input_path)
     else:
-        images, names = load_from_file(args.input_path, args.max)
+        images, names, text_lines = load_from_file(args.input_path, args.max)
         folder_name = os.path.basename(args.input_path).split(".")[0]
 
+    pnums = []
+    prev_name = None
+    for i, name in enumerate(names):
+        if prev_name is None or prev_name != name:
+            pnums.append(0)
+        else:
+            pnums.append(pnums[-1] + 1)
+        prev_name = name
+
     line_predictions = batch_text_detection(images, det_model, det_processor)
     layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
 
     table_boxes = []
-    for layout_pred in layout_predictions:
-        bbox = [l.bbox for l in layout_pred.bboxes if l.label == "Table"]
-        table_boxes.append(bbox)
+    table_cells = []
+    table_cells_text = []
 
-    order_predictions = batch_ordering(images, bboxes, model, processor)
+    table_imgs = []
+    table_counts = []
+    for layout_pred, text_line, img in zip(layout_predictions, text_lines, images):
+        # The bbox for the entire table
+        bbox = [l.bbox for l in layout_pred.bboxes if l.label == "Table"]
+        # Number of tables per page
+        table_counts.append(len(bbox))
+
+        if len(bbox) == 0:
+            continue
+
+        table_boxes.extend(bbox)
+
+        page_table_imgs = [img.crop(bb) for bb in bbox]
+        table_imgs.extend(page_table_imgs)
+
+        # The text cells inside each table
+        if text_line is None or args.detect_boxes:
+            cell_bboxes = batch_text_detection(page_table_imgs, det_model, det_processor)
+            cell_bboxes = [[tb.bbox for tb in table_box.bboxes] for table_box in cell_bboxes]
+            cell_text = [[None] * len(table_box) for table_box in cell_bboxes]
+            table_cells_text.extend(cell_text)
+            table_cells.extend(cell_bboxes)
+        else:
+            table_texts = get_table_blocks(bbox, text_line, img.size)
+            table_cells.extend(
+                [[tb["bbox"] for tb in table_text] for table_text in table_texts]
+            )
+            table_cells_text.extend(
+                [[tb["text"] for tb in table_text] for table_text in table_texts]
+            )
+
+    table_preds = batch_table_recognition(table_imgs, table_cells, model, processor)
 
     result_path = os.path.join(args.results_dir, folder_name)
     os.makedirs(result_path, exist_ok=True)
 
     if args.images:
-        for idx, (image, layout_pred, order_pred, name) in enumerate(zip(images, layout_predictions, order_predictions, names)):
-            polys = [l.polygon for l in order_pred.bboxes]
-            labels = [str(l.position) for l in order_pred.bboxes]
-            bbox_image = draw_polys_on_image(polys, copy.deepcopy(image), labels=labels, label_font_size=20)
-            bbox_image.save(os.path.join(result_path, f"{name}_{idx}_order.png"))
-
-    predictions_by_page = defaultdict(list)
-    for idx, (layout_pred, pred, name, image) in enumerate(zip(layout_predictions, order_predictions, names, images)):
-        out_pred = pred.model_dump()
-        for bbox, layout_bbox in zip(out_pred["bboxes"], layout_pred.bboxes):
-            bbox["label"] = layout_bbox.label
+        pass
 
-        out_pred["page"] = len(predictions_by_page[name]) + 1
-        predictions_by_page[name].append(out_pred)
+    img_idx = 0
+    prev_count = 0
+    table_predictions = defaultdict(list)
+    for i in range(sum(table_counts)):
+        while i >= prev_count + table_counts[img_idx]:
+            prev_count += table_counts[img_idx]
+            img_idx += 1
+
+        pred = table_preds[i]
+        orig_name = names[img_idx]
+        pnum = pnums[img_idx]
+        table_img = table_imgs[i]
+
+        out_pred = pred.model_dump()
+        out_pred["page"] = pnum + 1
+        table_idx = i - prev_count
+        out_pred["table_idx"] = table_idx
+        table_predictions[orig_name].append(out_pred)
 
-    # Sort in reading order
-    for name in predictions_by_page:
-        for page_preds in predictions_by_page[name]:
-            page_preds["bboxes"] = sorted(page_preds["bboxes"], key=lambda x: x["position"])
+        if args.images:
+            boxes = [l.bbox for l in pred.cells]
+            labels = [f"{l.row_id}/{l.col_id}" for l in pred.cells]
+            bbox_image = draw_bboxes_on_image(boxes, table_img, labels=labels, label_font_size=20)
+            bbox_image.save(os.path.join(result_path, f"{orig_name}_page{pnum + 1}_table{table_idx}_table.png"))
 
     with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
-        json.dump(predictions_by_page, f, ensure_ascii=False)
+        json.dump(table_predictions, f, ensure_ascii=False)
 
     print(f"Wrote results to {result_path}")
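The index bookkeeping in the output loop flattens tables across pages: table_counts records how many tables each page produced, and a flat table index is walked back to its page and position (pages with zero tables are skipped correctly by the while loop). The mapping in isolation:

    def flat_to_page(i, table_counts):
        img_idx, prev_count = 0, 0
        while i >= prev_count + table_counts[img_idx]:
            prev_count += table_counts[img_idx]
            img_idx += 1
        return img_idx, i - prev_count  # (page index, table index on that page)

    table_counts = [2, 0, 1]  # two tables on page 0, none on page 1, one on page 2
    assert [flat_to_page(i, table_counts) for i in range(3)] == [(0, 0), (0, 1), (2, 0)]

With the script in place, something like `python table_recognition.py input.pdf --images` (input.pdf is a hypothetical path; input_path is presumably a positional argument defined just above the hunk shown) writes per-table crops labeled row_id/col_id plus a results.json keyed by input name.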