Skip to content

Commit

Permalink
Merge pull request allegro#18 from allegro/dependency_update
Browse files Browse the repository at this point in the history
Dependency update + two minor fixes
  • Loading branch information
PrzemekPobrotyn authored Sep 7, 2020
2 parents 6167f9d + abac065 commit c90b64a
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 25 deletions.
8 changes: 4 additions & 4 deletions allrank/click_models/click_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, List
from typing import List, Tuple, Union

import numpy as np
import torch
Expand All @@ -7,8 +7,8 @@
from allrank.data.dataset_loading import PADDED_Y_VALUE


def click_on_slates(slates: Tuple[torch.Tensor, torch.Tensor], click_model: ClickModel, include_empty: bool) \
-> Tuple[List[torch.Tensor], List[List[int]]]:
def click_on_slates(slates: Union[Tuple[np.ndarray, np.ndarray], Tuple[torch.Tensor, torch.Tensor]],
click_model: ClickModel, include_empty: bool) -> Tuple[List[Union[np.ndarray, torch.Tensor]], List[List[int]]]:
"""
This metod runs a click model on a list of slates and returns new slates with `y` taken from clicks
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(self, inner_click_model: ClickModel):
"""
self.inner_click_model = inner_click_model

def click(self, documents: Tuple[torch.Tensor, torch.Tensor]):
def click(self, documents: Union[Tuple[np.ndarray, np.ndarray], Tuple[torch.Tensor, torch.Tensor]]) -> np.ndarray:
X, y = documents
padded_values_mask = y == PADDED_Y_VALUE
real_X = X[~padded_values_mask]
Expand Down
4 changes: 2 additions & 2 deletions allrank/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ def _parse_metrics(metrics):
metrics_dict = defaultdict(list) # type: Dict[str, list]
for metric_string in metrics:
try:
name, at = metric_string.split("@")
name, at = metric_string.split("_")
metrics_dict[name].append(int(at))
except (ValueError, TypeError):
raise MetricConfigError(
metric_string,
"Wrong formatting of metric in config. Expected format: <name>@<at> where name is valid metric name and at is and int")
"Wrong formatting of metric in config. Expected format: <name>_<at> where name is valid metric name and at is and int")
return metrics_dict


Expand Down
2 changes: 1 addition & 1 deletion allrank/config_template.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"early_stopping_patience": <int, number of epochs for early stopping>
"gradient_clipping_norm": <Optional[float]
},
"metrics": <list of metrics for evaluation of the form name@k where name is the name of the metric
"metrics": <list of metrics for evaluation of the form name_k where name is the name of the metric
as defined in models/metrics.py and k is the rank used for evaluation>,
"loss":
{
Expand Down
2 changes: 1 addition & 1 deletion allrank/training/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def compute_metrics(metrics, model, dl, dev):
metric_func = getattr(metrics_module, metric_name)
metric_func_with_ats = partial(metric_func, ats=ats)
metrics_values = metric_on_epoch(metric_func_with_ats, model, dl, dev)
metrics_names = ["{metric_name}@{at}".format(metric_name=metric_name, at=at) for at in ats]
metrics_names = ["{metric_name}_{at}".format(metric_name=metric_name, at=at) for at in ats]
metric_values_dict.update(dict(zip(metrics_names, metrics_values)))

return metric_values_dict
Expand Down
6 changes: 3 additions & 3 deletions scripts/local_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@
"early_stopping_patience": 100,
"gradient_clipping_norm": null
},
"val_metric": "ndcg@5",
"val_metric": "ndcg_5",
"metrics": [
"ndcg@5"
"ndcg_5"
],
"loss": {
"name": "ordinal",
Expand All @@ -57,7 +57,7 @@
},
"expected_metrics" : {
"val": {
"ndcg@5": 0.76
"ndcg_5": 0.76
}
}
}
6 changes: 3 additions & 3 deletions scripts/local_config_click_model.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@
"early_stopping_patience": 100,
"gradient_clipping_norm": null
},
"val_metric": "ndcg@5",
"val_metric": "ndcg_5",
"metrics": [
"ndcg@5"
"ndcg_5"
],
"loss": {
"name": "ordinal",
Expand All @@ -57,7 +57,7 @@
},
"expected_metrics" : {
"val": {
"ndcg@5": 0.76
"ndcg_5": 0.76
}
},
"click_model": {
Expand Down
22 changes: 11 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@
README = (HERE / "README.md").read_text()

reqs = [
"torch==1.4.0",
"torchvision==0.5.0",
"scikit-learn==0.22.1",
"pandas==0.25.3",
"numpy==1.18.1",
"scipy==1.4.1",
"attrs==19.3.0",
"flatten_dict==0.2.0",
"tensorboardX==2.0.0",
"gcsfs>=0.6.0"
"torch==1.6.0",
"torchvision==0.7.0",
"scikit-learn>=0.23.0",
"pandas>=1.0.5",
"numpy>=1.18.5",
"scipy>=1.4.1",
"attrs>=19.3.0",
"flatten_dict>=0.3.0",
"tensorboardX>=2.1.0",
"gcsfs==0.6.2"
]

setup(
name="allRank",
version="1.3.0",
version="1.3.1",
description="allRank is a framework for training learning-to-rank neural models",
long_description=README,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit c90b64a

Please sign in to comment.