Skip to content

Commit

Permalink
[BB2] BlenderBot FiD agent that uses Wizard of Internet's original re…
Browse files Browse the repository at this point in the history
…trieved docs (facebookresearch#4163)

* BB2 retrieved gold agent

* empty retrieved docs

* updated the teacher tests
  • Loading branch information
mojtaba-komeili authored Nov 12, 2021
1 parent 74dfeeb commit 1874977
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 32 deletions.
33 changes: 19 additions & 14 deletions parlai/agents/fid/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def get_retrieved_knowledge(self, message):
Extracts the retrieved knowledge from the message.
"""

def _set_query_vec(self, observation: Message) -> Message:
def show_observation_to_echo_retriever(self, observation: Message):
retrieved_docs = self.get_retrieved_knowledge(observation)
if len(retrieved_docs) > self._n_docs:
logging.warning(
Expand All @@ -365,6 +365,9 @@ def _set_query_vec(self, observation: Message) -> Message:
self.model_api.retriever.add_retrieve_doc(
observation[self._query_key], retrieved_docs
)

def _set_query_vec(self, observation: Message) -> Message:
self.show_observation_to_echo_retriever(observation)
super()._set_query_vec(observation)


Expand Down Expand Up @@ -394,21 +397,23 @@ def get_retrieved_knowledge(self, message: Message):
n_docs_in_message = len(message[consts.RETRIEVED_DOCS])
already_added_doc_idx = []

if ' '.join(selected_sentences) != consts.NO_SELECTED_SENTENCES_TOKEN:
for doc_idx in range(n_docs_in_message):
doc_content = message[consts.RETRIEVED_DOCS][doc_idx]
for sel_sentc in selected_sentences:
if sel_sentc in doc_content:
retrieved_docs.append(
self._extract_doc_from_message(message, doc_idx)
)
already_added_doc_idx.append(doc_idx)
break
if len(retrieved_docs) == self._n_docs:
logging.warning(
f'More than {self._n_docs} documents have selected sentences. Trimming them to the first {self._n_docs}'
if ' '.join(selected_sentences) == consts.NO_SELECTED_SENTENCES_TOKEN:
return retrieved_docs # `retrieved_docs` is empty at this point

for doc_idx in range(n_docs_in_message):
doc_content = message[consts.RETRIEVED_DOCS][doc_idx]
for sel_sentc in selected_sentences:
if sel_sentc in doc_content:
retrieved_docs.append(
self._extract_doc_from_message(message, doc_idx)
)
already_added_doc_idx.append(doc_idx)
break
if len(retrieved_docs) == self._n_docs and doc_idx != (self._n_docs - 1):
logging.warning(
f'More than {self._n_docs} documents have selected sentences. Trimming them to the first {self._n_docs}'
)
break

# Then adding other (filler) docs.
# We add them by iterating forward in the __retrieved-docs__ list for repeatability,
Expand Down
5 changes: 4 additions & 1 deletion parlai/agents/rag/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1283,14 +1283,17 @@ class ObservationEchoRetriever(RagRetriever):

def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared: TShared = None):
self._delimiter = '\n'
self.n_docs = opt['n_docs']
self._query_ids = dict()
self._saved_docs = dict()
super().__init__(opt, dictionary, shared=shared)

def add_retrieve_doc(self, query: str, retrieved_docs: List[Document]):
new_idx = len(self._query_ids)
self._query_ids[query] = new_idx
self._saved_docs[new_idx] = retrieved_docs or [BLANK_DOC]
self._saved_docs[new_idx] = retrieved_docs or [
BLANK_DOC for _ in range(self.n_docs)
]

def tokenize_query(self, query: str) -> List[int]:
return [self._query_ids[query]]
Expand Down
2 changes: 2 additions & 0 deletions parlai/tasks/wizard_of_internet/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,12 @@ def get_knowledge(msg_d):
if not knowledge[CONST.RETRIEVED_DOCS]:
knowledge[CONST.RETRIEVED_DOCS] = [CONST.NO_RETRIEVED_DOCS_TOKEN]
knowledge[CONST.RETRIEVED_DOCS_URLS] = [CONST.NO_URLS]
knowledge[CONST.RETRIEVED_DOCS_TITLES] = [CONST.NO_TITLE]

if not knowledge[CONST.SELECTED_DOCS]:
knowledge[CONST.SELECTED_DOCS] = [CONST.NO_SELECTED_DOCS_TOKEN]
knowledge[CONST.SELECTED_SENTENCES] = [CONST.NO_SELECTED_SENTENCES_TOKEN]
knowledge[CONST.SELECTED_DOCS_TITLES] = [CONST.NO_TITLE]

return knowledge

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,8 @@ acts:
Throwback video! When Salman Khan teased Katrina Kaif for missi...
Pulwama terror attack: Ram Gopal Varma taunts Pakistan Prime Mi...'
__select-docs-titles__: []
__select-docs-titles__:
- __no_title__
__selected-docs-urls__: []
__selected-docs__:
- __noselected-docs__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3128,12 +3128,14 @@ acts:
\ i forgot to mention there is also only a 20 or thirty minute time limit, depending\
\ on the round. The cuisine is from around the world. Fine dining to street\
\ style. "
- - __retrieved-docs-titles__: []
- - __retrieved-docs-titles__:
- __no_title__
__retrieved-docs-urls__:
- __no_urls__
__retrieved-docs__:
- __noretrieved-docs__
__select-docs-titles__: []
__select-docs-titles__:
- __no_title__
__selected-docs-urls__: []
__selected-docs__:
- __noselected-docs__
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
acts:
- - __retrieved-docs-titles__: []
- - __retrieved-docs-titles__:
- __no_title__
__retrieved-docs-urls__:
- __no_urls__
__retrieved-docs__:
- __noretrieved-docs__
__select-docs-titles__: []
__select-docs-titles__:
- __no_title__
__selected-docs-urls__: []
__selected-docs__:
- __noselected-docs__
Expand All @@ -17,12 +19,14 @@ acts:
search_query: __no_search_used__
text: "__knowledge__ __no_passages_used__ __endknowledge__ \n I work as a freelance\
\ accountant.\nI enjoy reading books. "
- - __retrieved-docs-titles__: []
- - __retrieved-docs-titles__:
- __no_title__
__retrieved-docs-urls__:
- __no_urls__
__retrieved-docs__:
- __noretrieved-docs__
__select-docs-titles__: []
__select-docs-titles__:
- __no_title__
__selected-docs-urls__: []
__selected-docs__:
- __noselected-docs__
Expand Down Expand Up @@ -786,7 +790,8 @@ acts:
Verification of Debentures | Guidelines for Auditors

Tags:Auditing, Role of Auditor'
__select-docs-titles__: []
__select-docs-titles__:
- __no_title__
__selected-docs-urls__: []
__selected-docs__:
- __noselected-docs__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,8 @@ acts:
Throwback video! When Salman Khan teased Katrina Kaif for missi...
Pulwama terror attack: Ram Gopal Varma taunts Pakistan Prime Mi...'
__select-docs-titles__: []
__select-docs-titles__:
- __no_title__
__selected-docs-urls__: []
__selected-docs__:
- __noselected-docs__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3124,12 +3124,14 @@ acts:
text: 'Yes, i forgot to mention there is also only a 20 or thirty minute time
limit, depending on the round. The cuisine is from around the world. Fine dining
to street style. '
- - __retrieved-docs-titles__: []
- - __retrieved-docs-titles__:
- __no_title__
__retrieved-docs-urls__:
- __no_urls__
__retrieved-docs__:
- __noretrieved-docs__
__select-docs-titles__: []
__select-docs-titles__:
- __no_title__
__selected-docs-urls__: []
__selected-docs__:
- __noselected-docs__
Expand Down
15 changes: 10 additions & 5 deletions parlai/tasks/wizard_of_internet/test/wizard_of_internet_valid.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
acts:
- - __retrieved-docs-titles__: []
- - __retrieved-docs-titles__:
- __no_title__
__retrieved-docs-urls__:
- __no_urls__
__retrieved-docs__:
- __noretrieved-docs__
__select-docs-titles__: []
__select-docs-titles__:
- __no_title__
__selected-docs-urls__: []
__selected-docs__:
- __noselected-docs__
Expand All @@ -18,12 +20,14 @@ acts:
text: 'I work as a freelance accountant.

I enjoy reading books. '
- - __retrieved-docs-titles__: []
- - __retrieved-docs-titles__:
- __no_title__
__retrieved-docs-urls__:
- __no_urls__
__retrieved-docs__:
- __noretrieved-docs__
__select-docs-titles__: []
__select-docs-titles__:
- __no_title__
__selected-docs-urls__: []
__selected-docs__:
- __noselected-docs__
Expand Down Expand Up @@ -786,7 +790,8 @@ acts:
Verification of Debentures | Guidelines for Auditors

Tags:Auditing, Role of Auditor'
__select-docs-titles__: []
__select-docs-titles__:
- __no_title__
__selected-docs-urls__: []
__selected-docs__:
- __noselected-docs__
Expand Down
10 changes: 9 additions & 1 deletion projects/blenderbot2/agents/blenderbot2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch.nn.functional as F
from typing import Union, Dict, List, Tuple, Optional, Any

from parlai.agents.fid.fid import FidAgent
from parlai.agents.fid.fid import FidAgent, WizIntGoldDocRetrieverFiDAgent
from parlai.agents.rag.args import DPR_ZOO_MODEL, QUERY_MODEL_TYPES
from parlai.agents.rag.rag import RagAgent
from parlai.agents.rag.model_types import (
Expand Down Expand Up @@ -892,3 +892,11 @@ def build_model(self) -> Union[BlenderBot2FidModel, T5BlenderBot2FidModel]:
model.encoder.embeddings.weight, self.opt['embedding_type']
)
return model


class BlenderBot2WizIntGoldDocRetrieverFiDAgent(
WizIntGoldDocRetrieverFiDAgent, BlenderBot2FidAgent
):
def _set_query_vec(self, observation: Message) -> Message:
self.show_observation_to_echo_retriever(observation)
super()._set_query_vec(observation)
9 changes: 9 additions & 0 deletions projects/blenderbot2/agents/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
RagRetrieverTokenizer,
SearchQuerySearchEngineRetriever,
SearchQueryFAISSIndexRetriever,
ObservationEchoRetriever,
)
from parlai.core.dict import DictionaryAgent
from parlai.core.opt import Opt
Expand Down Expand Up @@ -64,6 +65,8 @@ def retriever_factory(
return BB2SearchQuerySearchEngineRetriever(opt, dictionary, shared=shared)
elif retriever is RetrieverType.SEARCH_TERM_FAISS:
return BB2SearchQueryFaissIndexRetriever(opt, dictionary, shared=shared)
elif retriever is RetrieverType.OBSERVATION_ECHO_RETRIEVER:
return BB2ObservationEchoRetriever(opt, dictionary, shared=shared)
else:
return rag_retriever_factory(opt, dictionary, shared=shared)

Expand Down Expand Up @@ -908,3 +911,9 @@ class BB2SearchQueryFaissIndexRetriever(
"""
Override Search Engine Retriever to accommodate SQ Generator from BB2 Setup.
"""


class BB2ObservationEchoRetriever(BB2SearchRetrieverMixin, ObservationEchoRetriever):
"""
A retriever that reads retrieved docs as part of the observed example message.
"""

0 comments on commit 1874977

Please sign in to comment.