Skip to content

Commit

Permalink
removing none_throws from session-level recall RecMetric (pytorch#1134)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1134

See comment:

https://www.internalfb.com/diff/D44836701?dst_version_fbid=926606994977131&transaction_fbid=925560898753257

Reviewed By: lequytra, YLGH

Differential Revision: D45241562

fbshipit-source-id: 1292cd320ed94f82c3481892860a5aa550258062
  • Loading branch information
Mark Gluzman authored and facebook-github-bot committed Apr 25, 2023
1 parent 114bf35 commit 163b45f
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 12 deletions.
29 changes: 17 additions & 12 deletions torchrec/metrics/recall_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import torch
from torch import distributed as dist
from torchrec.distributed.utils import none_throws
from torchrec.metrics.metrics_config import RecTaskInfo, SessionMetricDef
from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
from torchrec.metrics.rec_metric import (
Expand Down Expand Up @@ -149,7 +148,7 @@ def update(
or self.session_var_name not in kwargs["required_inputs"]
):
raise RecMetricException(
"Need the {} input to update the bucket metric".format(
"Need the {} input to update the session metric".format(
self.session_var_name
)
)
Expand Down Expand Up @@ -239,15 +238,13 @@ def __init__(
"Fused update is not supported for recall session-level metrics"
)
for task in tasks:
session_metric_def = none_throws(
task.session_metric_def, "Please, specify the session metric definition"
)
if session_metric_def.top_threshold is None:
raise RecMetricException("Please, specify the top threshold")
if session_metric_def.session_var_name is None:
if task.session_metric_def is None:
raise RecMetricException(
"Please, specify the session var name in your model output"
"Please, specify the session metric definition"
)
session_metric_def = task.session_metric_def
if session_metric_def.top_threshold is None:
raise RecMetricException("Please, specify the top threshold")

super().__init__(
world_size=world_size,
Expand All @@ -266,15 +263,23 @@ def _get_task_kwargs(
) -> Dict[str, Any]:
if isinstance(task_config, list):
raise RecMetricException("Session metric can only take one task at a time")
return {"session_metric_def": none_throws(task_config.session_metric_def)}

if task_config.session_metric_def is None:
raise RecMetricException("Please, specify the session metric definition")

return {"session_metric_def": task_config.session_metric_def}

def _get_task_required_inputs(
self, task_config: Union[RecTaskInfo, List[RecTaskInfo]]
) -> Set[str]:
if isinstance(task_config, list):
raise RecMetricException("Session metric can only take one task at a time")

if task_config.session_metric_def is None:
raise RecMetricException("Please, specify the session metric definition")

return (
{none_throws(task_config.session_metric_def).session_var_name}
if none_throws(task_config.session_metric_def).session_var_name
{task_config.session_metric_def.session_var_name}
if task_config.session_metric_def.session_var_name
else set()
)
35 changes: 35 additions & 0 deletions torchrec/metrics/tests/test_recall_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch
from torch import no_grad
from torchrec.metrics.metrics_config import RecTaskInfo, SessionMetricDef
from torchrec.metrics.rec_metric import RecMetricException

from torchrec.metrics.recall_session import RecallSessionMetric

Expand Down Expand Up @@ -206,6 +207,40 @@ def test_recall_session_with_no_positive_examples(self) -> None:
)
raise

def test_error_messages(self) -> None:

task_info1 = RecTaskInfo(
name="Task1",
label_name="label1",
prediction_name="prediction1",
weight_name="weight1",
)

task_info2 = RecTaskInfo(
name="Task2",
label_name="label2",
prediction_name="prediction2",
weight_name="weight2",
session_metric_def=SessionMetricDef(session_var_name="session"),
)

error_message1 = "Please, specify the session metric definition"
with self.assertRaisesRegex(RecMetricException, error_message1):
_ = RecallSessionMetric(
world_size=1,
my_rank=5,
batch_size=100,
tasks=[task_info1],
)
error_message2 = "Please, specify the top threshold"
with self.assertRaisesRegex(RecMetricException, error_message2):
_ = RecallSessionMetric(
world_size=1,
my_rank=5,
batch_size=100,
tasks=[task_info2],
)

def test_tasks_input_propagation(self) -> None:
task_info1 = RecTaskInfo(
name="Task1",
Expand Down

0 comments on commit 163b45f

Please sign in to comment.