diff --git a/torchrec/metrics/recall_session.py b/torchrec/metrics/recall_session.py index d724c2993..49c57349f 100644 --- a/torchrec/metrics/recall_session.py +++ b/torchrec/metrics/recall_session.py @@ -10,7 +10,6 @@ import torch from torch import distributed as dist -from torchrec.distributed.utils import none_throws from torchrec.metrics.metrics_config import RecTaskInfo, SessionMetricDef from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix from torchrec.metrics.rec_metric import ( @@ -149,7 +148,7 @@ def update( or self.session_var_name not in kwargs["required_inputs"] ): raise RecMetricException( - "Need the {} input to update the bucket metric".format( + "Need the {} input to update the session metric".format( self.session_var_name ) ) @@ -239,15 +238,13 @@ def __init__( "Fused update is not supported for recall session-level metrics" ) for task in tasks: - session_metric_def = none_throws( - task.session_metric_def, "Please, specify the session metric definition" - ) - if session_metric_def.top_threshold is None: - raise RecMetricException("Please, specify the top threshold") - if session_metric_def.session_var_name is None: + if task.session_metric_def is None: raise RecMetricException( - "Please, specify the session var name in your model output" + "Please, specify the session metric definition" ) + session_metric_def = task.session_metric_def + if session_metric_def.top_threshold is None: + raise RecMetricException("Please, specify the top threshold") super().__init__( world_size=world_size, @@ -266,15 +263,23 @@ def _get_task_kwargs( ) -> Dict[str, Any]: if isinstance(task_config, list): raise RecMetricException("Session metric can only take one task at a time") - return {"session_metric_def": none_throws(task_config.session_metric_def)} + + if task_config.session_metric_def is None: + raise RecMetricException("Please, specify the session metric definition") + + return {"session_metric_def": task_config.session_metric_def} def _get_task_required_inputs( self, task_config: Union[RecTaskInfo, List[RecTaskInfo]] ) -> Set[str]: if isinstance(task_config, list): raise RecMetricException("Session metric can only take one task at a time") + + if task_config.session_metric_def is None: + raise RecMetricException("Please, specify the session metric definition") + return ( - {none_throws(task_config.session_metric_def).session_var_name} - if none_throws(task_config.session_metric_def).session_var_name + {task_config.session_metric_def.session_var_name} + if task_config.session_metric_def.session_var_name else set() ) diff --git a/torchrec/metrics/tests/test_recall_session.py b/torchrec/metrics/tests/test_recall_session.py index 8d6433203..af0c09401 100644 --- a/torchrec/metrics/tests/test_recall_session.py +++ b/torchrec/metrics/tests/test_recall_session.py @@ -11,6 +11,7 @@ import torch from torch import no_grad from torchrec.metrics.metrics_config import RecTaskInfo, SessionMetricDef +from torchrec.metrics.rec_metric import RecMetricException from torchrec.metrics.recall_session import RecallSessionMetric @@ -206,6 +207,40 @@ def test_recall_session_with_no_positive_examples(self) -> None: ) raise + def test_error_messages(self) -> None: + + task_info1 = RecTaskInfo( + name="Task1", + label_name="label1", + prediction_name="prediction1", + weight_name="weight1", + ) + + task_info2 = RecTaskInfo( + name="Task2", + label_name="label2", + prediction_name="prediction2", + weight_name="weight2", + session_metric_def=SessionMetricDef(session_var_name="session"), + ) + + error_message1 = "Please, specify the session metric definition" + with self.assertRaisesRegex(RecMetricException, error_message1): + _ = RecallSessionMetric( + world_size=1, + my_rank=5, + batch_size=100, + tasks=[task_info1], + ) + error_message2 = "Please, specify the top threshold" + with self.assertRaisesRegex(RecMetricException, error_message2): + _ = RecallSessionMetric( + world_size=1, + my_rank=5, + batch_size=100, + tasks=[task_info2], + ) + def test_tasks_input_propagation(self) -> None: task_info1 = RecTaskInfo( name="Task1",