Skip to content

Commit

Permalink
show top docs in eval (facebookresearch#3931)
Browse files Browse the repository at this point in the history
  • Loading branch information
klshuster authored Aug 12, 2021
1 parent d88cccc commit 70ee4a2
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion parlai/agents/rag/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from parlai.core.metrics import AverageMetric, normalize_answer, F1Metric
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.torch_agent import History, Batch
from parlai.core.torch_agent import History, Batch, Output
from parlai.core.torch_generator_agent import PPLMetric, TorchGeneratorAgent, TreeSearch
from parlai.utils.distributed import sync_parameters
from parlai.utils.io import PathManager
Expand Down Expand Up @@ -279,6 +279,15 @@ def observe(self, observation: Union[Dict, Message]) -> Message:
self._set_input_turn_cnt_vec(observation)
return observation

def eval_step(self, batch: Batch) -> Optional[Output]:
output = super().eval_step(batch)
if output is None or not hasattr(self.model, 'retriever'):
return output
assert isinstance(self.model, RagModel)
if hasattr(self.model.retriever, 'top_docs'):
output.top_docs = self.model.retriever.top_docs # type: ignore
return output

###### 1. Model Inputs ######

def _model_input(
Expand Down

0 comments on commit 70ee4a2

Please sign in to comment.