
updated main to adaparse incl. pagewise
7shoe committed Dec 4, 2024
1 parent 60683a2 commit 7db44a8
Showing 102 changed files with 454 additions and 674,408 deletions.
311 changes: 311 additions & 0 deletions adaparse/parsers/adaparse.py
@@ -0,0 +1,311 @@
"""The AdaParse PDF parser."""

from __future__ import annotations

import functools
from abc import ABC
from abc import abstractmethod
from pathlib import Path
from typing import Any
from typing import Literal

import torch
from pydantic import BaseModel
from pydantic import Field
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from pdfwf.parsers.base import BaseParser
from pdfwf.parsers.nougat_ import NougatParser
from pdfwf.parsers.nougat_ import NougatParserConfig
from pdfwf.parsers.pymupdf import PyMuPDFParser
from pdfwf.parsers.pymupdf import PyMuPDFParserConfig
from pdfwf.timer import Timer
from pdfwf.utils import exception_handler

__all__ = [
'AdaParse',
'AdaParseConfig',
]


class TextDataset(Dataset):
"""Dataset for sequence classification."""

def __init__(self, texts: list[str]) -> None:
"""Initialize the dataset."""
self.texts = texts

def __len__(self) -> int:
"""Return the number of text."""
return len(self.texts)

def __getitem__(self, idx: int) -> str:
"""Return a sequence."""
return self.texts[idx]


class TextClassifierConfig(BaseModel):
"""Settings for the text classifier."""

weights_path: Path = Field(
description='The path to the fine-tuned model weights.',
)
batch_size: int = Field(
default=8,
description='The batch size for the classifier.',
)
max_character_length: int = Field(
default=3200,
description='The maximum length of the input text (in characters).',
)
num_data_workers: int = Field(
default=1,
description='The number of data workers for the classifier.',
)
pin_memory: bool = Field(
default=True,
description='Whether to pin memory for the classifier.',
)


class TextClassifier(ABC):
"""Text classifier."""

def __init__(self, config: TextClassifierConfig) -> None:
"""Initialize the classifier."""
from peft import PeftModel
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Load the base model
model = AutoModelForSequenceClassification.from_pretrained(
'bert-base-uncased', num_labels=11
)

# Load the fine-tuned model with LoRA adapters
model = PeftModel.from_pretrained(model, config.weights_path)

# Move the model to the appropriate device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Set the model to evaluation mode
model.eval()

self.config = config
self.device = device
self.model = model
self.tokenizer = tokenizer

@abstractmethod
def decision_function(self, logits: torch.Tensor) -> torch.Tensor:
"""Return the decision function.
Parameters
----------
logits : torch.Tensor
The model logits.
Returns
-------
torch.Tensor
The decision function result (tensor of ints).
"""
...

@torch.no_grad()
def predict(self, text: list[str]) -> list[int]:
"""Classify the input text.
Parameters
----------
text : list[str]
The input text to classify.
Returns
-------
list[int]
The predicted classes.
"""
# Truncate the text
_text = [t[: self.config.max_character_length] for t in text]

# Create the dataset
dataset = TextDataset(_text)

# Create the data collator (tokenization function)
collater_fn = functools.partial(
self.tokenizer,
return_tensors='pt',
truncation=True,
padding=True,
return_special_tokens_mask=False,
)

# Create the data loader
dataloader = DataLoader(
dataset,
collate_fn=collater_fn,
batch_size=self.config.batch_size,
pin_memory=self.config.pin_memory,
num_workers=self.config.num_data_workers,
)

# Collect the predictions
predictions = []

# Iterate over each batch of the data loader
for batch in dataloader:
# Move the inputs to the appropriate device
inputs = {k: v.to(self.device) for k, v in batch.items()}

# Run the model forward pass
outputs = self.model(**inputs)

# Call the decision function
y_pred = self.decision_function(outputs.logits)

# Collect the predictions
predictions.extend(y_pred.tolist())

return predictions


class NougatTextClassifier(TextClassifier):
"""Text classifier for the Nougat parser."""

def decision_function(self, logits: torch.Tensor) -> torch.Tensor:
"""Return the decision function.
Parameters
----------
logits : torch.Tensor
The model logits.
Returns
-------
torch.Tensor
The decision function result (tensor of ints).
"""
# Get the predicted classes
y_pred = logits.argmax(dim=1)

# We only care about class 0 (high quality) and class 1 (low quality).
# Keep class 0 as 0 and assign 1 to all other classes so that those
# documents are re-parsed with Nougat.
y_pred[y_pred != 0] = 1

return y_pred


class AdaParseConfig(
PyMuPDFParserConfig, NougatParserConfig, TextClassifierConfig
):
"""Settings for the AdaParse parser."""

# The name of the parser.
name: Literal['adaparse'] = 'adaparse' # type: ignore[assignment]

# DEV NOTE: The following are convenience properties to access the
# individual parser configurations (we need a flat configuration for
# the parser to be compatible with the warmstart registry module).
@property
def pymupdf_config(self) -> PyMuPDFParserConfig:
"""Return the PyMuPDF parser configuration."""
return PyMuPDFParserConfig()

@property
def nougat_config(self) -> NougatParserConfig:
"""Return the Nougat parser configuration."""
return NougatParserConfig(
batchsize=self.batchsize,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
checkpoint=self.checkpoint,
mmd_out=self.mmd_out,
recompute=self.recompute,
full_precision=self.full_precision,
markdown=self.markdown,
skipping=self.skipping,
nougat_logs_path=self.nougat_logs_path,
)

@property
def classifier_config(self) -> TextClassifierConfig:
"""Return the text classifier configuration."""
return TextClassifierConfig(
weights_path=self.weights_path,
batch_size=self.batch_size,
max_character_length=self.max_character_length,
num_data_workers=self.num_data_workers,
pin_memory=self.pin_memory,
)


class AdaParse(BaseParser):
"""Interface for the AdaParse PDF parser."""

def __init__(self, config: AdaParseConfig) -> None:
"""Initialize the parser."""
# Initialize the PyMuPDF and Nougat parsers
self.pymupdf_parser = PyMuPDFParser(config=config.pymupdf_config)
self.nougat_parser = NougatParser(config=config.nougat_config)

# Initialize the quality check classifier. It returns a 0 or 1 for each
# parsed text: if 0, the PDF text as parsed by PyMuPDF is of high
# quality; otherwise, the PDF should be re-parsed with Nougat.
self.classifier = NougatTextClassifier(config=config.classifier_config)

@exception_handler(default_return=None)
def parse(self, pdf_files: list[str]) -> list[dict[str, Any]] | None:
"""Parse a list of pdf files and return the parsed data."""
# First, parse the PDFs using PyMuPDF
with Timer('adaparse-pymupdf-parsing', self.unique_id):
documents = self.pymupdf_parser.parse(pdf_files)

# If no documents, there was an error parsing the PDFs with PyMuPDF
if documents is None:
return None

# Apply the quality check classifier
with Timer('adaparse-quality-check', self.unique_id):
document_text = [d['text'] for d in documents]
qualities = self.classifier.predict(document_text)

# Log the percentage of low-quality documents
low_quality_num = sum(q != 0 for q in qualities)
low_quality_percentage = (low_quality_num / len(qualities)) * 100
print(f'Low-quality documents: {low_quality_percentage:.2f}%')

# Collect the pdf files that failed the quality check. Use the paths
# stored on the documents so the mapping stays correct even if some PDFs
# could not be parsed by PyMuPDF.
low_quality_pdfs = [d['path'] for d, q in zip(documents, qualities) if q != 0]

# Collect the documents that passed the quality check
documents = [d for d, q in zip(documents, qualities) if q == 0]

# If no low-quality documents, return the parsed documents
if not low_quality_pdfs:
return documents

# Parse the low-quality documents using the Nougat parser
with Timer('adaparse-nougat-parsing', self.unique_id):
nougat_documents = self.nougat_parser.parse(low_quality_pdfs)

# If Nougat documents were parsed, add them to the output
if nougat_documents is not None:
print(f'Nougat parsed documents: {len(nougat_documents)}')
documents.extend(nougat_documents)

# Finally, return the parsed documents from both parsers
return documents
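
For orientation, a minimal usage sketch of the new parser (not part of this commit): it assumes the module is importable as adaparse.parsers.adaparse after the rename, and the configuration values below are placeholders; which fields are actually required is determined by the merged PyMuPDF, Nougat, and classifier configs.

from adaparse.parsers.adaparse import AdaParse, AdaParseConfig

# Placeholder paths for illustration only.
config = AdaParseConfig(
    weights_path='/path/to/quality_classifier_lora',  # classifier LoRA adapters
    checkpoint='/path/to/nougat_checkpoint',          # Nougat model weights
    mmd_out='/path/to/mmd_output',                    # Nougat markdown output
    nougat_logs_path='/path/to/nougat_logs',          # Nougat log directory
)

parser = AdaParse(config)

# Returns None on failure; otherwise each entry typically contains
# 'path', 'text', 'metadata', and 'parser' keys.
documents = parser.parse(['/path/to/paper_1.pdf', '/path/to/paper_2.pdf'])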
11 changes: 8 additions & 3 deletions pdfwf/parsers/nougat_.py → adaparse/parsers/nougat_.py
@@ -248,10 +248,15 @@ def parse(self, pdf_files: list[str]) -> list[dict[str, Any]] | None: # noqa: P
if is_last_page[j]:
out = ''.join(predictions).strip()
out = re.sub(r'\n{3,}', '\n\n', out).strip()

# derive (approximate) page start character indices as cumulative
# sums of the page prediction lengths
page_indices = [sum(len(p) for p in predictions[:i]) for i in range(len(predictions))]

# TODO: Implement an LLM-based optional metadata extraction
# call to run on the first page for author and title.
# metadata
metadata = {'page_char_idx': page_indices}

# write the output document
document = {'path': str(is_last_page[j]), 'text': out, 'metadata': metadata, 'parser': 'nougat'}
documents.append(document)

if self.config.mmd_out is not None:
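As a quick illustration of the start-index bookkeeping above (a sketch, not part of the diff): since the page predictions are joined with an empty separator, the start index of page i is the cumulative length of all earlier pages.

predictions = ['abc', 'defg', 'hi']
page_indices = [
    sum(len(p) for p in predictions[:i]) for i in range(len(predictions))
]
# page_indices == [0, 3, 7]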
24 changes: 20 additions & 4 deletions pdfwf/parsers/pymupdf.py → adaparse/parsers/pymupdf.py
@@ -58,8 +58,22 @@ def parse_pdf(self, pdf_path: str) -> tuple[str, dict[str, str]] | None:

# Scrape text
text_list = []

# track the character index at which each page starts
cumm_idx = 0
page_indices = [0]

# loop over the pages
for page in doc:
# get the page's text
page_txt = page.get_text()
text_list.append(page_txt)
# advance the character index (account for the '\n' used to join pages)
cumm_idx += len(page_txt) + len('\n')
page_indices.append(cumm_idx)

# drop the trailing index that points past the last page
page_indices = page_indices[:-1]

full_text = '\n'.join(text_list)

# Get first page (as a proxy for `abstract`)
Expand Down Expand Up @@ -90,10 +104,11 @@ def parse_pdf(self, pdf_path: str) -> tuple[str, dict[str, str]] | None:
'format': form,
'first_page': first_page_text,
'abstract': abstract,
'page_char_idx': page_indices,
}

# explicitly close the document
doc.close()

# full text & metadata entries
return full_text, out_meta
@@ -120,6 +135,7 @@ def parse(self, pdf_files: list[str]) -> list[dict[str, Any]] | None:
'text': text,
'path': str(pdf_file),
'metadata': metadata,
'parser': self.config.name,
}
documents.append(document)

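To sketch how the new page_char_idx metadata could be consumed downstream (a hypothetical helper, not part of this commit; for the Nougat output the offsets are only approximate because the joined text is post-processed):

def page_text(document: dict, i: int) -> str:
    """Return the (approximate) text of page i from a parsed document."""
    indices = document['metadata']['page_char_idx']
    start = indices[i]
    end = indices[i + 1] if i + 1 < len(indices) else len(document['text'])
    return document['text'][start:end]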
