Skip to content

Commit

Permalink
fix: linknet hyperparameters postprocessing + demo for rotation model (
Browse files Browse the repository at this point in the history
…mindee#865)

* fix: linknet parameters

* feat: add demo rotation

* feat: add rotation in demo
  • Loading branch information
charlesmindee authored Mar 22, 2022
1 parent 9878d03 commit 9d03085
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 16 deletions.
12 changes: 8 additions & 4 deletions demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from doctr.models import ocr_predictor
from doctr.utils.visualization import visualize_page

DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large"]
DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large", "linknet_resnet18_rotation"]
RECO_ARCHS = ["crnn_vgg16_bn", "crnn_mobilenet_v3_small", "master", "sar_resnet31"]


Expand Down Expand Up @@ -73,7 +73,10 @@ def main():

else:
with st.spinner('Loading model...'):
predictor = ocr_predictor(det_arch, reco_arch, pretrained=True)
predictor = ocr_predictor(
det_arch, reco_arch, pretrained=True,
assume_straight_pages=(det_arch != "linknet_resnet18_rotation")
)

with st.spinner('Analyzing...'):

Expand All @@ -97,8 +100,9 @@ def main():

# Page reconsitution under input page
page_export = out.pages[0].export()
img = out.pages[0].synthesize()
cols[3].image(img, clamp=True)
if det_arch != "linknet_resnet18_rotation":
img = out.pages[0].synthesize()
cols[3].image(img, clamp=True)

# Display JSON
st.markdown("\nHere are your analysis results in JSON format:")
Expand Down
7 changes: 3 additions & 4 deletions doctr/models/detection/linknet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class LinkNetPostProcessor(DetectionPostProcessor):
"""
def __init__(
self,
bin_thresh: float = 0.5,
bin_thresh: float = 0.1,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
) -> None:
Expand All @@ -39,7 +39,7 @@ def __init__(
bin_thresh,
assume_straight_pages
)
self.unclip_ratio = 1.5
self.unclip_ratio = 1.2

def polygon_to_box(
self,
Expand Down Expand Up @@ -103,13 +103,12 @@ def bitmap_to_boxes(
containing x, y, w, h, alpha, score for the box
"""
height, width = bitmap.shape[:2]
min_size_box = 1 + int(height / 512)
boxes = []
# get contours from connected components on the bitmap
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
# Check whether smallest enclosing bounding box is not too small
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
continue
# Compute objectness
if self.assume_straight_pages:
Expand Down
18 changes: 10 additions & 8 deletions doctr/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,16 @@ def visualize_page(
int(page['dimensions'][1] * word['geometry'][0][0]),
int(page['dimensions'][0] * word['geometry'][0][1])
)
ax.text(
*text_loc,
word['value'],
size=10,
alpha=0.5,
color=(0, 0, 1),
)

if len(word['geometry']) == 2:
# We draw only if boxes are in straight format
ax.text(
*text_loc,
word['value'],
size=10,
alpha=0.5,
color=(0, 0, 1),
)

if display_artefacts:
for artefact in block['artefacts']:
Expand All @@ -251,7 +254,6 @@ def visualize_page(
def synthesize_page(
page: Dict[str, Any],
draw_proba: bool = False,
font_size: int = 13,
font_family: Optional[str] = None,
) -> np.ndarray:
"""Draw a the content of the element page (OCR response) on a blank page.
Expand Down

0 comments on commit 9d03085

Please sign in to comment.