diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 64bda8bcd..3e78f4756 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -29,7 +29,7 @@ "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.2.0/db_resnet50-adcafc63.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-84171458.zip&src=0", }, "db_mobilenet_v3_large": { "mean": (0.798, 0.785, 0.772), @@ -147,20 +147,24 @@ def __init__( _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape] output_shape = tuple(self.fpn(_inputs).shape) - self.probability_head = keras.Sequential([ - *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]), - layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"), - layers.BatchNormalization(), - layers.Activation("relu"), - layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"), - ]) - self.threshold_head = keras.Sequential([ - *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]), - layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"), - layers.BatchNormalization(), - layers.Activation("relu"), - layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"), - ]) + self.probability_head = keras.Sequential( + [ + *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]), + layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"), + layers.BatchNormalization(), + layers.Activation("relu"), + layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"), + ] + ) + self.threshold_head = keras.Sequential( + [ + *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]), + layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"), + layers.BatchNormalization(), + layers.Activation("relu"), + layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"), + ] + ) self.postprocessor = DBPostProcessor( assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh