Skip to content

Commit

Permalink
Throw an error if getattribute_from_module can't find anything (hug…
Browse files Browse the repository at this point in the history
…gingface#19535)

* return None to avoid recursive call

* Give error

* Give error

* Add test

* More tests

* Quality

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored Oct 12, 2022
1 parent 383ad81 commit 0968388
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
9 changes: 8 additions & 1 deletion src/transformers/models/auto/auto_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,14 @@ def getattribute_from_module(module, attr):
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
# object at the top level.
transformers_module = importlib.import_module("transformers")
return getattribute_from_module(transformers_module, attr)

if module != transformers_module:
try:
return getattribute_from_module(transformers_module, attr)
except ValueError:
raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!")
else:
raise ValueError(f"Could not find {attr} in {transformers_module}!")


class _LazyAutoMapping(OrderedDict):
Expand Down
24 changes: 23 additions & 1 deletion tests/models/auto/test_modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
import sys
import tempfile
import unittest
from collections import OrderedDict
from pathlib import Path

from transformers import BertConfig, is_torch_available
import pytest

from transformers import BertConfig, GPT2Model, is_torch_available
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.testing_utils import (
DUMMY_UNKNOWN_IDENTIFIER,
Expand Down Expand Up @@ -372,3 +375,22 @@ def test_cached_model_has_minimum_calls_to_head(self):
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)

def test_attr_not_existing(self):

from transformers.models.auto.auto_factory import _LazyAutoMapping

_CONFIG_MAPPING_NAMES = OrderedDict([("bert", "BertConfig")])
_MODEL_MAPPING_NAMES = OrderedDict([("bert", "GhostModel")])
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)

with pytest.raises(ValueError, match=r"Could not find GhostModel neither in .* nor in .*!"):
_MODEL_MAPPING[BertConfig]

_MODEL_MAPPING_NAMES = OrderedDict([("bert", "BertModel")])
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
self.assertEqual(_MODEL_MAPPING[BertConfig], BertModel)

_MODEL_MAPPING_NAMES = OrderedDict([("bert", "GPT2Model")])
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
self.assertEqual(_MODEL_MAPPING[BertConfig], GPT2Model)

0 comments on commit 0968388

Please sign in to comment.