[RAG] Handle Different DPR Model File with Pre-trained Model (facebookresearch#3688)

* handle dpr model file correctly

* update comments

* different way of doing this

* fix import
klshuster authored Jun 8, 2021
1 parent 4683821 commit 64d3859
Showing 2 changed files with 132 additions and 1 deletion.
29 changes: 29 additions & 0 deletions parlai/agents/rag/rag.py
@@ -22,6 +22,7 @@
from parlai.agents.hugging_face.t5 import T5Agent
from parlai.agents.transformer.polyencoder import PolyencoderAgent
from parlai.agents.transformer.transformer import TransformerGeneratorAgent
from parlai.core.build_data import modelzoo_path
from parlai.core.dict import DictionaryAgent
from parlai.core.message import Message
from parlai.core.metrics import AverageMetric, normalize_answer, F1Metric
@@ -402,13 +403,41 @@ def update_state_dict(
        )
        return state_dict

    def _should_override_dpr_model_weights(self, opt: Opt):
        """
        Determine if we need to override the DPR Model weights.

        Under certain circumstances, one may wish to specify a different
        `--dpr-model-file` for a pre-trained RAG model. Thus, we additionally
        check to make sure that the loaded DPR model weights are not overwritten
        by the state loading.
        """
        override_dpr = False
        overrides = opt.get('override', {})
        if overrides.get('dpr_model_file') and os.path.exists(
            overrides['dpr_model_file']
        ):
            override_dpr = True
            logging.warning(
                f"Overriding DPR Model with {modelzoo_path(opt['datapath'], opt['dpr_model_file'])}"
            )
        return override_dpr

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        """
        Potentially update state dict with relevant RAG components.

        Useful when initializing from a normal seq2seq model.
        """
        try:
            if self._should_override_dpr_model_weights(self.opt):
                state_dict.update(
                    {
                        f"retriever.{k}": v
                        for k, v in self.model.retriever.state_dict().items()  # type: ignore
                    }
                )
            super().load_state_dict(state_dict)
        except RuntimeError:
            state_dict = self.update_state_dict(self.opt, state_dict, self.model)
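For reference, the behavior this commit enables can be exercised with a short script like the sketch below (not part of the diff). It mirrors the Opt-override pattern used in the new test further down; the two zoo paths are assumptions and stand in for whichever pre-trained RAG and DPR checkpoints you want to combine. Equivalently, on the command line: `parlai em -mf <rag model file> --dpr-model-file <dpr model file>`, as in the test docstring below.

# Sketch only: load a pre-trained RAG agent while pointing it at a different DPR
# query encoder. With this commit, the freshly loaded DPR weights are copied into
# the incoming state dict (under the "retriever." prefix) so the RAG checkpoint
# does not overwrite them.
from parlai.core.agents import create_agent
from parlai.core.build_data import modelzoo_path
from parlai.core.params import ParlaiParser, Opt

opt = ParlaiParser(True, True).parse_args([])
rag_model_file = 'zoo:hallucination/bart_rag_token/model'  # assumed zoo path
dpr_model_file = 'zoo:hallucination/multiset_dpr/hf_bert_base.cp'  # assumed zoo path
rag = create_agent(
    Opt(
        {
            'model_file': modelzoo_path(opt['datapath'], rag_model_file),
            'override': {
                'dpr_model_file': modelzoo_path(opt['datapath'], dpr_model_file),
                'fp16': False,
            },
        }
    )
)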
104 changes: 103 additions & 1 deletion tests/nightly/gpu/test_rag.py
@@ -4,14 +4,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import torch
import torch.cuda
from typing import Optional
import unittest

from parlai.core.build_data import modelzoo_path
from parlai.core.agents import create_agent
from parlai.core.params import ParlaiParser, Opt
import parlai.utils.testing as testing_utils

try:
    import parlai.agents.rag.dpr  # noqa: F401
    from parlai.agents.rag.dpr import DprQueryEncoder
except ImportError:
    pass

@@ -388,5 +392,103 @@ def test_bart_fid_rag_dpr_poly(self):
        _test_zoo_file(FID_RAG_DPR_POLY_ZOO_MODEL, True, True)


class TestLoadDPRModel(unittest.TestCase):
    """
    Test loading different DPR models for RAG.

    Suppose we have the following models:
    1. A: Default DPR Model
    2. M: RAG Model trained with A
    3. B: Resulting DPR Model after training M
    4. C: DPR Model from training a different RAG Model

    The following should hold true:
    1. `parlai em -mf M` -> M.DPR (B) != A
    2. `parlai em -mf M --dpr-model-file A` -> M.DPR == A
    3. `parlai em -mf M --dpr-model-file C` -> M.DPR (C) != B
    """

    def test_load_dpr(self):
        opt = ParlaiParser(True, True).parse_args([])
        # First, we'll load up a DPR model from the zoo dpr file.
        default_query_encoder = DprQueryEncoder(
            opt, dpr_model='bert', pretrained_path=DPR_ZOO_MODEL
        )
        rag_sequence_query_encoder = DprQueryEncoder(
            opt,
            dpr_model='bert_from_parlai_rag',
            pretrained_path=RAG_SEQUENCE_ZOO_MODEL,
        )
        assert not torch.allclose(
            default_query_encoder.embeddings.weight.float().cpu(),
            rag_sequence_query_encoder.embeddings.weight.float().cpu(),
        )
        # 1. Create a zoo RAG Agent, which involves a trained DPR model
        rag = create_agent(
            Opt(
                {
                    'model_file': modelzoo_path(opt['datapath'], RAG_TOKEN_ZOO_MODEL),
                    'override': {'retriever_debug_index': 'compressed', 'fp16': False},
                }
            )
        )
        # The default rag token model should have different query encoders
        # from both the RAG_SEQUENCE_ZOO_MODEL and the default DPR_ZOO_MODEL.
        assert not torch.allclose(
            rag_sequence_query_encoder.embeddings.weight.float().cpu(),
            rag.model.retriever.query_encoder.embeddings.weight.float().cpu(),
        )
        assert not torch.allclose(
            default_query_encoder.embeddings.weight.float().cpu(),
            rag.model.retriever.query_encoder.embeddings.weight.float().cpu(),
        )

        # 2. Create a RAG Agent with the RAG_SEQUENCE_ZOO_MODEL DPR model
        rag = create_agent(
            Opt(
                {
                    'model_file': modelzoo_path(opt['datapath'], RAG_TOKEN_ZOO_MODEL),
                    'override': {
                        'retriever_debug_index': 'compressed',
                        'dpr_model_file': modelzoo_path(
                            opt['datapath'], RAG_SEQUENCE_ZOO_MODEL
                        ),
                        'query_model': 'bert_from_parlai_rag',
                        'fp16': False,
                    },
                }
            )
        )
        # If we override the DPR model file, we should now have the same
        # weights as the query encoder from above.
        assert torch.allclose(
            rag_sequence_query_encoder.embeddings.weight.float().cpu(),
            rag.model.retriever.query_encoder.embeddings.weight.float().cpu(),
        )

        # 3. Create a RAG Agent with the default DPR zoo model
        rag = create_agent(
            Opt(
                {
                    'model_file': modelzoo_path(opt['datapath'], RAG_TOKEN_ZOO_MODEL),
                    'override': {
                        'retriever_debug_index': 'compressed',
                        'dpr_model_file': modelzoo_path(opt['datapath'], DPR_ZOO_MODEL),
                        'fp16': False,
                    },
                }
            )
        )

        # This model was trained with the DPR_ZOO_MODEL, yet it should now have the
        # same weights as the default query encoder, since we explicitly specified it.
        assert torch.allclose(
            default_query_encoder.embeddings.weight.float().cpu(),
            rag.model.retriever.query_encoder.embeddings.weight.float().cpu(),
        )


if __name__ == '__main__':
    unittest.main()
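As a side note on the mechanism in rag.py above: a parent nn.Module's state dict stores submodule parameters under dotted prefixes, so refreshing the checkpoint's "retriever.*" entries from the live retriever before calling load_state_dict preserves the retriever's current weights. Below is a toy, self-contained illustration with made-up module names (not ParlAI code).

import torch
import torch.nn as nn


class Retriever(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)


class RagLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.retriever = Retriever()
        self.generator = nn.Linear(4, 4)


model = RagLike()  # its retriever holds the weights we want to keep
checkpoint = RagLike().state_dict()  # stands in for a trained RAG checkpoint
before = model.retriever.proj.weight.clone()

# Refresh the checkpoint's retriever entries with the live weights, mirroring
# the f"retriever.{k}" update in the load_state_dict change above.
checkpoint.update(
    {f"retriever.{k}": v for k, v in model.retriever.state_dict().items()}
)
model.load_state_dict(checkpoint)

# The retriever weights survived loading; everything else came from the checkpoint.
assert torch.allclose(model.retriever.proj.weight, before)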
