refactor type to module_type (huggingface#70)
* refactor `type` to `module_type`

* fix style

* fix last types

* remove accidentally committed file
lvwerra authored May 27, 2022
1 parent 4852bcd commit 5051232
Showing 14 changed files with 56 additions and 51 deletions.
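For callers of the library, the rename is a one-keyword change to `evaluate.load` (plus the matching field in `EvaluationModuleInfo`). Below is a minimal sketch of the before/after usage, assuming a version of 🤗 Evaluate that includes this commit and network access to fetch the canonical `word_length` measurement script:

```python
import evaluate

# Before this commit the keyword argument was `type`:
# word_length = evaluate.load("word_length", type="measurement")

# After this commit it is `module_type`:
word_length = evaluate.load("word_length", module_type="measurement")

results = word_length.compute(data=["hello world", "foo bar foo bar"])
print(results)  # e.g. {'average_word_length': ...}
```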
2 changes: 1 addition & 1 deletion comparisons/mcnemar/app.py
@@ -2,5 +2,5 @@
from evaluate.utils import launch_gradio_widget


module = evaluate.load("mcnemar", type="comparison")
module = evaluate.load("mcnemar", module_type="comparison")
launch_gradio_widget(module)
2 changes: 1 addition & 1 deletion comparisons/mcnemar/mcnemar.py
@@ -63,7 +63,7 @@
class McNemar(evaluate.EvaluationModule):
def _info(self):
return evaluate.EvaluationModuleInfo(
type="comparison",
module_type="comparison",
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
4 changes: 2 additions & 2 deletions docs/source/a_quick_tour.mdx
@@ -24,15 +24,15 @@ Any metric, comparison, or measurement is loaded with the `evaluate.load` functi
If you want to make sure you are loading the right type of module (especially if there are name clashes) you can explicitely pass the type:

```py
->>> word_length = evaluate.load("word_length", type="measurement")
+>>> word_length = evaluate.load("word_length", module_type="measurement")
```

### Community modules

Besides the modules implemented in 🤗 Evaluate you can also load any community module by prepending the users name:

```py
->>> element_count = evaluate.load("lvwerra/element_count", type="measurement")
+>>> element_count = evaluate.load("lvwerra/element_count", module_type="measurement")
```

## Module attributes
6 changes: 3 additions & 3 deletions measurements/word_length/README.md
@@ -25,7 +25,7 @@ This measurement requires a list of strings as input:

```python
>>> data = ["hello world"]
->>> wordlength = evaluate.load("word_length", type="measurement")
+>>> wordlength = evaluate.load("word_length", module_type="measurement")
>>> results = wordlength.compute(data=data)
```

@@ -50,7 +50,7 @@ Example for a single string

```python
>>> data = ["hello sun and goodbye moon"]
->>> wordlength = evaluate.load("word_length", type="measurement")
+>>> wordlength = evaluate.load("word_length", module_type="measurement")
>>> results = wordlength.compute(data=data)
>>> print(results)
{'average_length': 5}
@@ -59,7 +59,7 @@ Example for a single string
Example for a multiple strings
```python
>>> data = ["hello sun and goodbye moon", "foo bar foo bar"]
->>> wordlength = evaluate.load("word_length", type="measurement")
+>>> wordlength = evaluate.load("word_length", module_type="measurement")
>>> results = wordlength.compute(data=text)
{'average_length': 4.5}
```
2 changes: 1 addition & 1 deletion measurements/word_length/app.py
@@ -2,5 +2,5 @@
from evaluate.utils import launch_gradio_widget


module = evaluate.load("word_length", type="measurement")
module = evaluate.load("word_length", module_type="measurement")
launch_gradio_widget(module)
4 changes: 2 additions & 2 deletions measurements/word_length/word_length.py
@@ -34,7 +34,7 @@
Examples:
>>> data = ["hello world"]
->>> wordlength = evaluate.load("word_length", type="measurement")
+>>> wordlength = evaluate.load("word_length", module_type="measurement")
>>> results = wordlength.compute(data=data)
>>> print(results)
{'average_word_length': 2}
@@ -57,7 +57,7 @@ def _info(self):
# TODO: Specifies the evaluate.EvaluationModuleInfo object
return evaluate.EvaluationModuleInfo(
# This is the description that will appear on the modules page.
type="measurement",
module_type="measurement",
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
6 changes: 3 additions & 3 deletions metrics/bleurt/README.md
@@ -31,7 +31,7 @@ This metric takes as input lists of predicted sentences and reference sentences:
```python
>>> predictions = ["hello there", "general kenobi"]
>>> references = ["hello there", "general kenobi"]
>>> bleurt = load("bleurt", type="metric")
>>> bleurt = load("bleurt", module_type="metric")
>>> results = bleurt.compute(predictions=predictions, references=references)
```

@@ -63,7 +63,7 @@ Example with the default model:
```python
>>> predictions = ["hello there", "general kenobi"]
>>> references = ["hello there", "general kenobi"]
>>> bleurt = load("bleurt", type="metric")
>>> bleurt = load("bleurt", module_type="metric")
>>> results = bleurt.compute(predictions=predictions, references=references)
>>> print(results)
{'scores': [1.0295498371124268, 1.0445425510406494]}
@@ -73,7 +73,7 @@ Example with the `"bleurt-base-128"` model checkpoint:
```python
>>> predictions = ["hello there", "general kenobi"]
>>> references = ["hello there", "general kenobi"]
>>> bleurt = load("bleurt", type="metric", checkpoint="bleurt-base-128")
>>> bleurt = load("bleurt", module_type="metric", checkpoint="bleurt-base-128")
>>> results = bleurt.compute(predictions=predictions, references=references)
>>> print(results)
{'scores': [1.0295498371124268, 1.0445425510406494]}
2 changes: 1 addition & 1 deletion src/evaluate/info.py
@@ -52,7 +52,7 @@ class EvaluationModuleInfo:
reference_urls: List[str] = field(default_factory=list)
streamable: bool = False
format: Optional[str] = None
type: str = "metric"
module_type: str = "metric"

# Set later by the builder
metric_name: Optional[str] = None
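Module authors set the renamed field through `EvaluationModuleInfo` in `_info`, exactly as the `mcnemar` and `word_length` modules above do. A hedged sketch of a custom module following that pattern (the class name, strings, and feature schema are illustrative, not part of this commit):

```python
import datasets
import evaluate


class ExactStringMatch(evaluate.EvaluationModule):
    """Toy metric: fraction of predictions that exactly match their reference."""

    def _info(self):
        return evaluate.EvaluationModuleInfo(
            module_type="metric",  # was `type="metric"` before this commit
            description="Fraction of predictions that exactly match their reference.",
            citation="",
            inputs_description="predictions: list of str, references: list of str",
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
        )

    def _compute(self, predictions, references):
        matches = [int(p == r) for p, r in zip(predictions, references)]
        return {"exact_match": sum(matches) / len(matches)}
```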
52 changes: 27 additions & 25 deletions src/evaluate/loading.py
@@ -407,14 +407,14 @@ class GithubEvaluationModuleFactory(_EvaluationModuleFactory):
def __init__(
self,
name: str,
-type: str,
+module_type: str,
revision: Optional[Union[str, Version]] = None,
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[DownloadMode] = None,
dynamic_modules_path: Optional[str] = None,
):
self.name = name
-self.type = type
+self.module_type = module_type
self.revision = revision
self.download_config = download_config or DownloadConfig()
self.download_mode = download_mode
@@ -423,7 +423,9 @@ def __init__(
increase_load_count(name, resource_type="metric")

def download_loading_script(self, revision: Optional[str]) -> str:
-file_path = hf_github_url(path=self.name, name=self.name + ".py", type=self.type, revision=revision)
+file_path = hf_github_url(
+    path=self.name, name=self.name + ".py", module_type=self.module_type, revision=revision
+)
download_config = self.download_config.copy()
if download_config.download_desc is None:
download_config.download_desc = "Downloading builder script"
@@ -448,7 +450,7 @@ def get_module(self) -> ImportableModule:
imports = get_imports(local_path)
local_imports = _download_additional_modules(
name=self.name,
-base_path=hf_github_url(path=self.name, name="", type=self.type, revision=revision),
+base_path=hf_github_url(path=self.name, name="", module_type=self.module_type, revision=revision),
imports=imports,
download_config=self.download_config,
)
@@ -459,7 +461,7 @@ def get_module(self) -> ImportableModule:
local_imports=local_imports,
additional_files=[],
dynamic_modules_path=dynamic_modules_path,
-module_namespace=self.type,
+module_namespace=self.module_type,
name=self.name,
download_mode=self.download_mode,
)
@@ -474,13 +476,13 @@ class LocalEvaluationModuleFactory(_EvaluationModuleFactory):
def __init__(
self,
path: str,
type: str = "metrics",
module_type: str = "metrics",
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[DownloadMode] = None,
dynamic_modules_path: Optional[str] = None,
):
self.path = path
-self.type = type
+self.module_type = module_type
self.name = Path(path).stem
self.download_config = download_config or DownloadConfig()
self.download_mode = download_mode
@@ -502,7 +504,7 @@ def get_module(self) -> ImportableModule:
local_imports=local_imports,
additional_files=[],
dynamic_modules_path=dynamic_modules_path,
-module_namespace=self.type,
+module_namespace=self.module_type,
name=self.name,
download_mode=self.download_mode,
)
@@ -517,14 +519,14 @@ class HubEvaluationModuleFactory(_EvaluationModuleFactory):
def __init__(
self,
name: str,
type: str = "metrics",
module_type: str = "metrics",
revision: Optional[Union[str, Version]] = None,
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[DownloadMode] = None,
dynamic_modules_path: Optional[str] = None,
):
self.name = name
-self.type = type
+self.module_type = module_type
self.revision = revision
self.download_config = download_config or DownloadConfig()
self.download_mode = download_mode
@@ -556,7 +558,7 @@ def get_module(self) -> ImportableModule:
local_imports=local_imports,
additional_files=[],
dynamic_modules_path=dynamic_modules_path,
-module_namespace=self.type,
+module_namespace=self.module_type,
name=self.name,
download_mode=self.download_mode,
)
@@ -574,17 +576,17 @@ class CachedEvaluationModuleFactory(_EvaluationModuleFactory):
def __init__(
self,
name: str,
type: str = "metrics",
module_type: str = "metrics",
dynamic_modules_path: Optional[str] = None,
):
self.name = name
-self.type = type
+self.module_type = module_type
self.dynamic_modules_path = dynamic_modules_path
assert self.name.count("/") == 0

def get_module(self) -> ImportableModule:
dynamic_modules_path = self.dynamic_modules_path if self.dynamic_modules_path else init_dynamic_modules()
-importable_directory_path = os.path.join(dynamic_modules_path, self.type, self.name)
+importable_directory_path = os.path.join(dynamic_modules_path, self.module_type, self.name)
hashes = (
[h for h in os.listdir(importable_directory_path) if len(h) == 64]
if os.path.isdir(importable_directory_path)
@@ -604,14 +606,14 @@ def _get_modification_time(module_hash):
f"couldn't be found locally at {self.name}, or remotely on the Hugging Face Hub."
)
# make the new module to be noticed by the import system
module_path = ".".join([os.path.basename(dynamic_modules_path), self.type, self.name, hash, self.name])
module_path = ".".join([os.path.basename(dynamic_modules_path), self.module_type, self.name, hash, self.name])
importlib.invalidate_caches()
return ImportableModule(module_path, hash)


def evaluation_module_factory(
path: str,
-type: Optional[str] = None,
+module_type: Optional[str] = None,
revision: Optional[Union[str, Version]] = None,
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[DownloadMode] = None,
@@ -680,7 +682,7 @@ def evaluation_module_factory(
# load a canonical evaluation module from hub
if path.count("/") == 0:
# if no type provided look through all possible modules
-if type is None:
+if module_type is None:
for current_type in ["metric", "comparison", "measurement"]:
try:
return GithubEvaluationModuleFactory(
Expand All @@ -694,11 +696,11 @@ def evaluation_module_factory(
except FileNotFoundError:
pass
raise FileNotFoundError
-# if type provided load specific type
+# if module_type provided load specific module_type
else:
return GithubEvaluationModuleFactory(
path,
-type,
+module_type,
revision=revision,
download_config=download_config,
download_mode=download_mode,
@@ -730,7 +732,7 @@ def load(
def load(
path: str,
config_name: Optional[str] = None,
-type: Optional[str] = None,
+module_type: Optional[str] = None,
process_id: int = 0,
num_process: int = 1,
cache_dir: Optional[str] = None,
@@ -750,9 +752,9 @@ def load(
- a local path to processing script or the directory containing the script (if the script has the same name as the directory),
e.g. ``'./metrics/rouge'`` or ``'./metrics/rogue/rouge.py'``
- a evaluation module identifier on the HuggingFace evaluate repo e.g. ``'rouge'`` or ``'bleu'`` that are in either ``'metrics/'``,
-``'comparisons/'``, or ``'measurements/'`` depending on the provided ``type``.
+``'comparisons/'``, or ``'measurements/'`` depending on the provided ``module_type``.
config_name (:obj:`str`, optional): selecting a configuration for the metric (e.g. the GLUE metric has a configuration for each subset)
-type (:obj:`str`, default ``'metric'``): type of evaluation module, can be one of ``'metric'``, ``'comparison'``, or ``'measurement'``.
+module_type (:obj:`str`, default ``'metric'``): type of evaluation module, can be one of ``'metric'``, ``'comparison'``, or ``'measurement'``.
process_id (:obj:`int`, optional): for distributed evaluation: id of the process
num_process (:obj:`int`, optional): for distributed evaluation: total number of processes
cache_dir (Optional str): path to store the temporary predictions and references (default to `~/.cache/huggingface/evaluate/`)
@@ -770,7 +772,7 @@ def load(
"""
download_mode = DownloadMode(download_mode or DownloadMode.REUSE_DATASET_IF_EXISTS)
evaluation_module = evaluation_module_factory(
-path, type=type, revision=revision, download_config=download_config, download_mode=download_mode
+path, module_type=module_type, revision=revision, download_config=download_config, download_mode=download_mode
).module_path
evaluation_cls = import_main_class(evaluation_module)
evaluation_instance = evaluation_cls(
Expand All @@ -783,9 +785,9 @@ def load(
**init_kwargs,
)

-if type and type != evaluation_instance.type:
+if module_type and module_type != evaluation_instance.module_type:
raise TypeError(
f"No module of type '{type}' not found for '{path}' locally, or on the Hugging Face Hub. Found module of type '{evaluation_instance.type}' instead."
f"No module of module type '{module_type}' not found for '{path}' locally, or on the Hugging Face Hub. Found module of module type '{evaluation_instance.module_type}' instead."
)

# Download and prepare resources for the metric
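The mismatch check near the end of `load` now compares the requested `module_type` against the module's declared `module_type`. A hedged sketch of what a mismatch looks like from the caller's side, reusing the community measurement from the docs snippet above (requires network access and assumes that module still declares itself a measurement):

```python
import evaluate

try:
    # element_count is a measurement, so requesting it as a metric should fail loudly.
    evaluate.load("lvwerra/element_count", module_type="metric")
except TypeError as err:
    print(err)  # reports the requested and the found module type
```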
6 changes: 3 additions & 3 deletions src/evaluate/module.py
@@ -137,8 +137,8 @@ def format(self) -> Optional[str]:
return self._module_info.format

@property
-def type(self) -> str:
-    return self._module_info.type
+def module_type(self) -> str:
+    return self._module_info.module_type


class EvaluationModule(EvaluationModuleInfoMixin):
@@ -238,7 +238,7 @@ def __len__(self):

def __repr__(self):
return (
f'EvaluationModule(name: "{self.name}", type: "{self.type}", '
f'EvaluationModule(name: "{self.name}", module_type: "{self.module_type}", '
f'features: {self.features}, usage: """{self.inputs_description}""", '
f"stored examples: {len(self)})"
)
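On the consumer side the old `.type` property becomes `.module_type`, and the repr reports it under the new name. A small sketch, assuming the canonical `accuracy` metric and its dependencies are available:

```python
import evaluate

accuracy = evaluate.load("accuracy")  # module_type defaults to "metric"

print(accuracy.module_type)  # replaces the old `accuracy.type`
print(accuracy)              # repr now contains: module_type: "metric"
```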
10 changes: 5 additions & 5 deletions src/evaluate/utils/file_utils.py
@@ -96,19 +96,19 @@ def head_hf_s3(
)


-def hf_github_url(path: str, name: str, type: str, revision: Optional[str] = None) -> str:
+def hf_github_url(path: str, name: str, module_type: str, revision: Optional[str] = None) -> str:
from .. import SCRIPTS_VERSION

revision = revision or os.getenv("HF_SCRIPTS_VERSION", SCRIPTS_VERSION)
if type == "metric":
if module_type == "metric":
return config.REPO_METRICS_URL.format(revision=revision, path=path, name=name)
elif type == "comparison":
elif module_type == "comparison":
return config.REPO_COMPARISONS_URL.format(revision=revision, path=path, name=name)
elif type == "measurement":
elif module_type == "measurement":
return config.REPO_MEASUREMENTS_URL.format(revision=revision, path=path, name=name)
else:
raise TypeError(
f"The evaluation type {type} is not supported. Should be one of 'metric', 'comparison', 'measurement'"
f"The evaluation module type {module_type} is not supported. Should be one of 'metric', 'comparison', 'measurement'"
)


@@ -65,7 +65,7 @@ def _info(self):
# TODO: Specifies the evaluate.EvaluationModuleInfo object
return evaluate.EvaluationModuleInfo(
# This is the description that will appear on the modules page.
type="{{ cookiecutter.module_type }}",
module_type="{{ cookiecutter.module_type }}",
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
7 changes: 5 additions & 2 deletions tests/test_load.py
@@ -65,7 +65,7 @@ def test_GithubMetricModuleFactory_with_internal_import(self):
# "squad_v2" requires additional imports (internal)
factory = GithubEvaluationModuleFactory(
"squad_v2",
type="metric",
module_type="metric",
download_config=self.download_config,
dynamic_modules_path=self.dynamic_modules_path,
)
@@ -75,7 +75,10 @@ def test_GithubMetricModuleFactory_with_external_import(self):
def test_GithubMetricModuleFactory_with_external_import(self):
# "bleu" requires additional imports (external from github)
factory = GithubEvaluationModuleFactory(
"bleu", type="metric", download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
"bleu",
module_type="metric",
download_config=self.download_config,
dynamic_modules_path=self.dynamic_modules_path,
)
module_factory_result = factory.get_module()
assert importlib.import_module(module_factory_result.module_path) is not None