Skip to content

Commit

Permalink
feat: ✨ tf db_resnet50 checkpoint (mindee#1480)
Browse files Browse the repository at this point in the history
  • Loading branch information
odulcy-mindee authored Feb 23, 2024
1 parent d547ef9 commit b2f9b17
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions doctr/models/detection/differentiable_binarization/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.2.0/db_resnet50-adcafc63.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-84171458.zip&src=0",
},
"db_mobilenet_v3_large": {
"mean": (0.798, 0.785, 0.772),
Expand Down Expand Up @@ -147,20 +147,24 @@ def __init__(
_inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
output_shape = tuple(self.fpn(_inputs).shape)

self.probability_head = keras.Sequential([
*conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
layers.BatchNormalization(),
layers.Activation("relu"),
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
])
self.threshold_head = keras.Sequential([
*conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
layers.BatchNormalization(),
layers.Activation("relu"),
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
])
self.probability_head = keras.Sequential(
[
*conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
layers.BatchNormalization(),
layers.Activation("relu"),
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
]
)
self.threshold_head = keras.Sequential(
[
*conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
layers.BatchNormalization(),
layers.Activation("relu"),
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
]
)

self.postprocessor = DBPostProcessor(
assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
Expand Down

0 comments on commit b2f9b17

Please sign in to comment.