Enable instantiating model with pretrained backbone weights (huggingface#28214)

* Enable instantiating model with pretrained backbone weights

* Update tests so backbone checkpoint isn't passed in

* Remove doc updates until changes made in modeling code

* Clarify pretrained import

* Update configs - docs and validation check

* Update src/transformers/utils/backbone_utils.py

Co-authored-by: Arthur <[email protected]>

* Clarify exception message

* Update config init in tests

* Add test for when use_timm_backbone=True

* Small test updates

---------

Co-authored-by: Arthur <[email protected]>
amyeroberts and ArthurZucker authored Jan 23, 2024
1 parent 008a6a2 commit 27c79a0
Showing 31 changed files with 362 additions and 37 deletions.
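
To make the intent of the change concrete, here is a minimal usage sketch (not part of the diff; `resnet50` is an illustrative timm checkpoint name, and the pretrained path assumes timm is installed and may download weights):

    from transformers import DetrConfig, DetrForObjectDetection

    # Pretrained timm backbone weights (the combination this commit supports).
    config = DetrConfig(backbone="resnet50", use_timm_backbone=True, use_pretrained_backbone=True)
    model = DetrForObjectDetection(config)

    # Same backbone architecture, randomly initialized weights.
    config = DetrConfig(backbone="resnet50", use_timm_backbone=True, use_pretrained_backbone=False)
    model = DetrForObjectDetection(config)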
7 changes: 2 additions & 5 deletions src/transformers/models/auto/auto_factory.py
@@ -602,10 +602,6 @@ def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *mod

         config = kwargs.pop("config", TimmBackboneConfig())

-        use_timm = kwargs.pop("use_timm_backbone", True)
-        if not use_timm:
-            raise ValueError("`use_timm_backbone` must be `True` for timm backbones")
-
         if kwargs.get("out_features", None) is not None:
             raise ValueError("Cannot specify `out_features` for timm backbones")

@@ -627,7 +623,8 @@ def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *mod

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        if kwargs.get("use_timm_backbone", False):
+        use_timm_backbone = kwargs.pop("use_timm_backbone", False)
+        if use_timm_backbone:
             return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
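
The switch from `kwargs.get` to `kwargs.pop` is the substantive fix here: `get` leaves `use_timm_backbone` in `kwargs`, which then gets forwarded to loaders that each had to strip it themselves (the block deleted above). A standalone sketch of the dispatch pattern after the change, with hypothetical loader stubs standing in for the real methods:

    def load_timm_backbone(name, **kwargs):
        print("timm path:", name, kwargs)

    def load_transformers_backbone(name, **kwargs):
        print("transformers path:", name, kwargs)

    def from_pretrained(name, **kwargs):
        # pop, not get: consume the flag so it is never forwarded downstream
        use_timm_backbone = kwargs.pop("use_timm_backbone", False)
        if use_timm_backbone:
            return load_timm_backbone(name, **kwargs)
        return load_transformers_backbone(name, **kwargs)

    from_pretrained("resnet50", use_timm_backbone=True, out_indices=(1, 2, 3))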
src/transformers/models/conditional_detr/configuration_conditional_detr.py
@@ -93,11 +93,11 @@ class ConditionalDetrConfig(PretrainedConfig):
         position_embedding_type (`str`, *optional*, defaults to `"sine"`):
             Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
         backbone (`str`, *optional*, defaults to `"resnet50"`):
-            Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
-            backbone from the timm package. For a list of all available models, see [this
-            page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
-            Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
+            Whether to use pretrained weights for the backbone.
         dilation (`bool`, *optional*, defaults to `False`):
             Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
             `use_timm_backbone` = `True`.
@@ -180,6 +180,14 @@ def __init__(
         focal_alpha=0.25,
         **kwargs,
     ):
+        if not use_timm_backbone and use_pretrained_backbone:
+            raise ValueError(
+                "Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
+            )
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
         if backbone_config is not None and use_timm_backbone:
             raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")

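
These guards fail fast at config-construction time. A brief sketch of the observable behavior (the messages are the ones added above):

    from transformers import ConditionalDetrConfig

    # Supported: timm backbone, optionally with pretrained weights.
    ConditionalDetrConfig(use_timm_backbone=True, use_pretrained_backbone=True)

    # Rejected: pretrained transformers-library backbones are not wired up yet.
    try:
        ConditionalDetrConfig(use_timm_backbone=False, use_pretrained_backbone=True)
    except ValueError as err:
        print(err)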
src/transformers/models/deformable_detr/configuration_deformable_detr.py
@@ -85,11 +85,11 @@ class DeformableDetrConfig(PretrainedConfig):
         position_embedding_type (`str`, *optional*, defaults to `"sine"`):
             Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
         backbone (`str`, *optional*, defaults to `"resnet50"`):
-            Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
-            backbone from the timm package. For a list of all available models, see [this
-            page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
-            Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
+            Whether to use pretrained weights for the backbone.
         dilation (`bool`, *optional*, defaults to `False`):
             Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
             `use_timm_backbone` = `True`.
@@ -196,6 +196,14 @@ def __init__(
         disable_custom_kernels=False,
         **kwargs,
     ):
+        if not use_timm_backbone and use_pretrained_backbone:
+            raise ValueError(
+                "Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
+            )
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
         if backbone_config is not None and use_timm_backbone:
             raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")

18 changes: 17 additions & 1 deletion src/transformers/models/deta/configuration_deta.py
@@ -40,6 +40,12 @@ class DetaConfig(PretrainedConfig):
     Args:
         backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `ResNetConfig()`):
             The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
         num_queries (`int`, *optional*, defaults to 900):
             Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetaModel`] can
             detect in a single image. In case `two_stage` is set to `True`, we use `two_stage_num_proposals` instead.
@@ -138,6 +144,8 @@ class DetaConfig(PretrainedConfig):
     def __init__(
         self,
         backbone_config=None,
+        backbone=None,
+        use_pretrained_backbone=False,
         num_queries=900,
         max_position_embeddings=2048,
         encoder_layers=6,
@@ -177,7 +185,13 @@ def __init__(
         focal_alpha=0.25,
         **kwargs,
     ):
-        if backbone_config is None:
+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
+        if backbone_config is None and backbone is None:
             logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
             backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage2", "stage3", "stage4"])
         else:
@@ -187,6 +201,8 @@ def __init__(
                 backbone_config = config_class.from_dict(backbone_config)

         self.backbone_config = backbone_config
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
         self.num_queries = num_queries
         self.max_position_embeddings = max_position_embeddings
         self.d_model = d_model
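
Since the config now persists `backbone` and `use_pretrained_backbone`, they round-trip through serialization like any other field. A small sketch of the behavior as added in this diff:

    from transformers import DetaConfig

    config = DetaConfig()  # falls back to the default ResNet backbone_config
    print(config.backbone)                 # None
    print(config.use_pretrained_backbone)  # False

    # The two ways of choosing a backbone remain mutually exclusive.
    try:
        DetaConfig(backbone_config=config.backbone_config, backbone="resnet50")
    except ValueError as err:
        print(err)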
18 changes: 13 additions & 5 deletions src/transformers/models/detr/configuration_detr.py
@@ -93,11 +93,11 @@ class DetrConfig(PretrainedConfig):
         position_embedding_type (`str`, *optional*, defaults to `"sine"`):
             Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
         backbone (`str`, *optional*, defaults to `"resnet50"`):
-            Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
-            backbone from the timm package. For a list of all available models, see [this
-            page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
-        use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
-            Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
+            Whether to use pretrained weights for the backbone.
         dilation (`bool`, *optional*, defaults to `False`):
             Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
             `use_timm_backbone` = `True`.
@@ -177,6 +177,14 @@ def __init__(
         eos_coefficient=0.1,
         **kwargs,
     ):
+        if not use_timm_backbone and use_pretrained_backbone:
+            raise ValueError(
+                "Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
+            )
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
         if backbone_config is not None and use_timm_backbone:
             raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")

18 changes: 17 additions & 1 deletion src/transformers/models/dpt/configuration_dpt.py
@@ -111,6 +111,12 @@ class DPTConfig(PretrainedConfig):
         backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
             The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to
             leverage the [`AutoBackbone`] API.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.

     Example:

@@ -161,16 +167,24 @@ def __init__(
         backbone_featmap_shape=[1, 1024, 24, 24],
         neck_ignore_stages=[0, 1],
         backbone_config=None,
+        backbone=None,
+        use_pretrained_backbone=False,
         **kwargs,
     ):
         super().__init__(**kwargs)

         self.hidden_size = hidden_size
         self.is_hybrid = is_hybrid

+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
         use_autobackbone = False
         if self.is_hybrid:
-            if backbone_config is None:
+            if backbone_config is None and backbone is None:
                 logger.info("Initializing the config with a `BiT` backbone.")
                 backbone_config = {
                     "global_padding": "same",
@@ -213,6 +227,8 @@ def __init__(
             self.backbone_featmap_shape = None
             self.neck_ignore_stages = []

+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
         self.num_hidden_layers = None if use_autobackbone else num_hidden_layers
         self.num_attention_heads = None if use_autobackbone else num_attention_heads
         self.intermediate_size = None if use_autobackbone else intermediate_size
20 changes: 18 additions & 2 deletions src/transformers/models/mask2former/configuration_mask2former.py
@@ -47,6 +47,12 @@ class Mask2FormerConfig(PretrainedConfig):
         backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `SwinConfig()`):
             The configuration of the backbone model. If unset, the configuration corresponding to
             `swin-base-patch4-window12-384` will be used.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
         feature_size (`int`, *optional*, defaults to 256):
             The features (channels) of the resulting feature maps.
         mask_feature_size (`int`, *optional*, defaults to 256):
@@ -154,9 +160,17 @@ def __init__(
         use_auxiliary_loss: bool = True,
         feature_strides: List[int] = [4, 8, 16, 32],
         output_auxiliary_logits: bool = None,
+        backbone=None,
+        use_pretrained_backbone=False,
         **kwargs,
     ):
-        if backbone_config is None:
+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
+        if backbone_config is None and backbone is None:
             logger.info("`backbone_config` is `None`. Initializing the config with the default `Swin` backbone.")
             backbone_config = CONFIG_MAPPING["swin"](
                 image_size=224,
@@ -177,7 +191,7 @@ def __init__(
             backbone_config = config_class.from_dict(backbone_config)

         # verify that the backbone is supported
-        if backbone_config.model_type not in self.backbones_supported:
+        if backbone_config is not None and backbone_config.model_type not in self.backbones_supported:
             logger.warning_once(
                 f"Backbone {backbone_config.model_type} is not a supported model and may not be compatible with Mask2Former. "
                 f"Supported model types: {','.join(self.backbones_supported)}"
@@ -212,6 +226,8 @@ def __init__(
         self.feature_strides = feature_strides
         self.output_auxiliary_logits = output_auxiliary_logits
         self.num_hidden_layers = decoder_layers
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone

         super().__init__(**kwargs)

20 changes: 18 additions & 2 deletions src/transformers/models/maskformer/configuration_maskformer.py
@@ -57,6 +57,12 @@ class MaskFormerConfig(PretrainedConfig):
         backbone_config (`Dict`, *optional*):
             The configuration passed to the backbone, if unset, the configuration corresponding to
             `swin-base-patch4-window12-384` will be used.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
         decoder_config (`Dict`, *optional*):
             The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50`
             will be used.
@@ -114,9 +120,17 @@ def __init__(
         cross_entropy_weight: float = 1.0,
         mask_weight: float = 20.0,
         output_auxiliary_logits: Optional[bool] = None,
+        backbone: Optional[str] = None,
+        use_pretrained_backbone: bool = False,
         **kwargs,
     ):
-        if backbone_config is None:
+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
+        if backbone_config is None and backbone is None:
             # fall back to https://huggingface.co/microsoft/swin-base-patch4-window12-384-in22k
             backbone_config = SwinConfig(
                 image_size=384,
@@ -136,7 +150,7 @@ def __init__(
             backbone_config = config_class.from_dict(backbone_config)

         # verify that the backbone is supported
-        if backbone_config.model_type not in self.backbones_supported:
+        if backbone_config is not None and backbone_config.model_type not in self.backbones_supported:
             logger.warning_once(
                 f"Backbone {backbone_config.model_type} is not a supported model and may not be compatible with MaskFormer. "
                 f"Supported model types: {','.join(self.backbones_supported)}"
@@ -177,6 +191,8 @@ def __init__(

         self.num_attention_heads = self.decoder_config.encoder_attention_heads
         self.num_hidden_layers = self.decoder_config.num_hidden_layers
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
         super().__init__(**kwargs)

     @classmethod
19 changes: 17 additions & 2 deletions src/transformers/models/oneformer/configuration_oneformer.py
@@ -44,6 +44,12 @@ class OneFormerConfig(PretrainedConfig):
     Args:
         backbone_config (`PretrainedConfig`, *optional*, defaults to `SwinConfig`):
             The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
         ignore_value (`int`, *optional*, defaults to 255):
             Values to be ignored in GT label while calculating loss.
         num_queries (`int`, *optional*, defaults to 150):
@@ -144,6 +150,8 @@ class OneFormerConfig(PretrainedConfig):
     def __init__(
         self,
         backbone_config: Optional[Dict] = None,
+        backbone: Optional[str] = None,
+        use_pretrained_backbone: bool = False,
         ignore_value: int = 255,
         num_queries: int = 150,
         no_object_weight: int = 0.1,
@@ -186,7 +194,13 @@ def __init__(
         common_stride: int = 4,
         **kwargs,
     ):
-        if backbone_config is None:
+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
+        if backbone_config is None and backbone is None:
             logger.info("`backbone_config` is unset. Initializing the config with the default `Swin` backbone.")
             backbone_config = CONFIG_MAPPING["swin"](
                 image_size=224,
@@ -206,7 +220,8 @@ def __init__(
                 backbone_config = config_class.from_dict(backbone_config)

         self.backbone_config = backbone_config
-
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
         self.ignore_value = ignore_value
         self.num_queries = num_queries
         self.no_object_weight = no_object_weight