[BB3] Memory Heuristics (facebookresearch#4770)
* memory heuristics

* small changes

* address comments

* fix config

* reqs
klshuster authored Sep 9, 2022
1 parent fae3e39 commit 58b6977
Showing 8 changed files with 453 additions and 140 deletions.
13 changes: 7 additions & 6 deletions .circleci/config.yml
@@ -98,6 +98,7 @@ commands:
pip install -v -r requirements.txt
python setup.py develop
python -c "import nltk; nltk.download('punkt')"
python -c "import nltk; nltk.download('stopwords')"
installtorchgpu:
description: Install torch GPU and dependencies
@@ -207,26 +208,26 @@ commands:
- setupcuda
- fixgit
- restore_cache:
key: deps-20220819c-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
key: deps-20220831-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
- setup
- installdeps
- << parameters.more_installs >>
- save_cache:
key: deps-20220819c-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
key: deps-20220831-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
paths:
- "~/venv/bin"
- "~/venv/lib"
- findtests:
marker: << parameters.marker >>
- restore_cache:
key: data-20220819c-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
key: data-20220831-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
- run:
name: Run tests
no_output_timeout: 60m
command: |
coverage run -m pytest -m << parameters.marker >> << parameters.pytest_flags >> --junitxml=test-results/junit.xml
- save_cache:
key: data-20220819c-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
key: data-20220831-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
paths:
- "~/ParlAI/data"
- codecov
@@ -243,12 +244,12 @@ commands:
- checkout
- fixgit
- restore_cache:
key: deps-20220819c-bw-{{ checksum "requirements.txt" }}
key: deps-20220831-bw-{{ checksum "requirements.txt" }}
- setup
- installdeps
- installtorchgpu
- save_cache:
key: deps-20220819c-bw-{{ checksum "requirements.txt" }}
key: deps-20220831-bw-{{ checksum "requirements.txt" }}
paths:
- "~/venv/bin"
- "~/venv/lib"
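Note: the new `nltk.download('stopwords')` step most plausibly supports the word-overlap memory heuristic introduced later in this commit (see the `memory_overlap_threshold` preset key below); that connection is an inference from the diff, not something it states. A minimal sketch of a stopword-filtered overlap score, assuming this is roughly how such a threshold would be applied:

```python
import nltk
from nltk.corpus import stopwords

nltk.download('stopwords')
STOPWORDS = set(stopwords.words('english'))


def content_overlap(text_a: str, text_b: str) -> float:
    """Fraction of text_b's content (non-stopword) words that also appear in text_a."""
    words_a = {w for w in text_a.lower().split() if w not in STOPWORDS}
    words_b = {w for w in text_b.lower().split() if w not in STOPWORDS}
    return len(words_a & words_b) / max(len(words_b), 1)


print(content_overlap("I love hiking with my dog", "my dog likes long hikes"))  # 0.25
```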
7 changes: 5 additions & 2 deletions parlai/opt_presets/gen/opt_bb3.opt
@@ -73,6 +73,10 @@
"mkm_penalize_repetitions": true,
"mkm_model": "projects.bb3.agents.opt_api_agent:BB3OPTAgent",
"mkm_server": "opt_server",
"ignore_in_session_memories_mkm": false,
"memory_overlap_threshold": 0,
"memory_hard_block_for_n_turns": 0,
"memory_soft_block_decay_factor": 0,

"contextual_knowledge_control_token": "",
"contextual_knowledge_decision": "compute",
@@ -192,6 +196,5 @@
"include_prompt": false,
"knowledge_chunk_size": 100,
"max_prompt_len": 1912,
"all_vanilla_prompt": false,
"ignore_in_session_memories_mkm": false
"all_vanilla_prompt": false
}
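The four new preset keys are the memory heuristics this commit adds: `ignore_in_session_memories_mkm` skips memories written earlier in the same session, `memory_overlap_threshold` presumably requires some lexical overlap with the context, `memory_hard_block_for_n_turns` presumably blocks a just-used memory for a fixed number of turns, and `memory_soft_block_decay_factor` presumably does so probabilistically with decay. A hypothetical, self-contained sketch of how such filters could combine (the function name and the dict-value semantics are assumptions, and the overlap check sketched above for config.yml would slot in as one more filter):

```python
import random
from typing import Dict, List, Set


def filter_available_memories(
    memories: Dict[str, int],
    in_session_memories: Set[str],
    ignore_in_session_memories: bool = False,
    memory_hard_block_for_n_turns: int = 0,
    memory_soft_block_decay_factor: float = 0.0,
) -> List[str]:
    """Keep only memories that pass the enabled heuristics (0 disables a heuristic).

    Assumes the dict value is "turns since this memory was last used",
    with a large sentinel for memories that have never been used.
    """
    available = []
    for memory, turns_since_used in memories.items():
        # optionally ignore memories that were written earlier in this same session
        if ignore_in_session_memories and memory in in_session_memories:
            continue
        # hard block: a memory used within the last N turns is unavailable
        if turns_since_used < memory_hard_block_for_n_turns:
            continue
        # soft block: a recently used memory is only reused with decaying probability
        if memory_soft_block_decay_factor > 0:
            p_block = memory_soft_block_decay_factor ** turns_since_used
            if random.random() < p_block:
                continue
        available.append(memory)
    return available


NEVER_USED = 10 ** 9
memories = {
    "partner's persona: I have two cats": 1,       # used on the previous turn
    "partner's persona: I live in Peru": NEVER_USED,
}
print(filter_available_memories(memories, set(), memory_hard_block_for_n_turns=2))
```

With every value at 0, as in these presets, each heuristic is disabled and behavior should match the previous default.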
7 changes: 5 additions & 2 deletions parlai/opt_presets/gen/opt_pt.opt
@@ -73,6 +73,10 @@
"mkm_penalize_repetitions": false,
"mkm_model": "projects.bb3.agents.opt_api_agent:BB3OPTAgent",
"mkm_server": "opt_server",
"ignore_in_session_memories_mkm": false,
"memory_overlap_threshold": 0,
"memory_hard_block_for_n_turns": 0,
"memory_soft_block_decay_factor": 0,

"contextual_knowledge_control_token": "",
"contextual_knowledge_decision": "compute",
@@ -192,6 +196,5 @@
"include_prompt": true,
"knowledge_chunk_size": 100,
"max_prompt_len": 1912,
"all_vanilla_prompt": false,
"ignore_in_session_memories_mkm": false
"all_vanilla_prompt": false
}
105 changes: 58 additions & 47 deletions projects/bb3/agents/opt_bb3_agent.py
@@ -153,13 +153,6 @@ def add_cmdline_args(
help='Number of times to retry on API request failures (< 0 for unlimited retry).',
)
parser.add_argument('--metaseq-server-timeout', default=20.0, type=float)
parser.add_argument(
'--ignore-in-session-memories-mkm',
type='bool',
default=False,
help='If true, we do not look at the in-session memories when '
'generating from the MKM',
)
return parser

def __init__(self, opt, shared=None):
@@ -177,12 +170,8 @@ def __init__(self, opt, shared=None):
self.agents[Module.SEARCH_KNOWLEDGE.agent_name()] = top_agent
self.agents[Module.SEARCH_KNOWLEDGE] = top_agent

self.dictionary = top_agent.dictionary
# continue
self.max_prompt_len = opt.get('max_prompt_len', PROMPT.MAX_PROMPT_LEN)
self.ignore_in_session_memories_mkm = opt.get(
'ignore_in_session_memories_mkm', False
)
self.search_agent = SearchAgent(
{
'server': self.opt.get('search_server', 'default'),
@@ -270,7 +259,7 @@ def get_mdm_observation(self, ag_obs: Message) -> Message:
return ag_obs

def get_orm_observation(
self, observation: Message, opening_memories: List[str]
self, observation: Message, opening_memories: Dict[str, int]
) -> Message:
"""
Return the appropriate ORM observation.
@@ -285,27 +274,65 @@ def get_orm_observation(
"""
agent = self.agents[Module.OPENING_DIALOGUE]
agent.reset()
for i, mem in enumerate(opening_memories):
prefixed_memories = {}
for mem, val in opening_memories.items():
mem = MemoryUtils.maybe_add_memory_prefix(mem, 'partner', self.MODEL_TYPE)
opening_memories[i] = mem
prefixed_memories[mem] = val

new_obs = copy.deepcopy(observation)
new_obs.force_set(
'text', self._check_and_limit_len('\n'.join(opening_memories))
memories_to_use = MemoryUtils.get_available_memories(
'',
prefixed_memories,
set(),
dictionary=self.dictionary,
**self._get_memory_heuristic_values(),
)
new_obs.force_set('memories', opening_memories)
if not memories_to_use:
# we need at least one memory to open with...
memories_to_use = random.sample(list(prefixed_memories.keys()), 1)

new_obs.force_set('text', self._check_and_limit_len('\n'.join(memories_to_use)))
new_obs.force_set('memories', prefixed_memories)

return agent.observe(new_obs)

def get_opening_memories(
self, memories: Optional[Union[List[str], Dict[str, int]]]
) -> Optional[Dict[str, int]]:
"""
Get the opening memories, if applicable.
This function is designed to handle legacy cases where memories are
presented as a list, rather than a dict.
:param memories:
memories from the opening message
:return opening_memories:
return the set of true opening memories
"""
opening_memories = None
if memories:
if isinstance(memories, dict):
opening_memories = memories
elif isinstance(memories, list):
opening_memories = {}
for mem in memories:
opening_memories = MemoryUtils.add_memory(mem, opening_memories)

assert not opening_memories or isinstance(opening_memories, dict)
return opening_memories

def observe(self, observation: Message) -> Dict[Module, Message]:
# handle passed memories as well
observation = Message(observation)
opening_memories = observation.pop('memories', None)
opening_memories = self.get_opening_memories(observation.pop('memories', None))
observations = super().observe(observation)
for m in Module.dialogue_modules():
ag_obs = copy.deepcopy(observation)
observations[m] = self.agents[m].observe(ag_obs)
if is_opener(observation['text'], opening_memories):
assert opening_memories
orm_obs = self.get_orm_observation(observation, opening_memories)
self.memories = orm_obs['memories']
observations[Module.OPENING_DIALOGUE] = orm_obs
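In the opening path, legacy list-style memories are first normalized into the dict form (`get_opening_memories`), then re-filtered through `MemoryUtils.get_available_memories`, with a fallback to one randomly sampled memory so the bot always has something to open with. A rough sketch of that shape, where `to_memory_dict` is a hypothetical stand-in for the `MemoryUtils.add_memory` loop above and the initial value of 0 is an assumption:

```python
import random
from typing import Dict, List, Optional, Union


def to_memory_dict(
    memories: Optional[Union[List[str], Dict[str, int]]]
) -> Optional[Dict[str, int]]:
    """Normalize legacy list-style memories into the dict form the heuristics expect.

    Assumption: a fresh memory starts with value 0; the real
    MemoryUtils.add_memory may record something different.
    """
    if not memories:
        return None
    if isinstance(memories, dict):
        return memories
    return {mem: 0 for mem in memories}


# Opening turn: if the heuristics filter out every memory, the agent still
# needs at least one memory to open with, so it falls back to sampling one.
memories = to_memory_dict(
    ["partner's persona: I have two cats", "partner's persona: I live in Peru"]
)
available: List[str] = []  # pretend the heuristics rejected everything
opener = available or random.sample(list(memories.keys()), 1)
print(opener)
```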
@@ -425,10 +452,7 @@ def batch_act_knowledge(
for module in Module:
obs = all_obs[module]
if module is Module.MEMORY_KNOWLEDGE and i in memory_indices:
memories = MemoryUtils.maybe_reduce_memories(
all_obs['raw']['text'], available_memory[i], self.dictionary
)
memories = '\n'.join(memories)
memories = '\n'.join(available_memory[i])
new_prompt = self._check_and_limit_len(
obs['prompt'].replace(module.opt_pre_context_tok(), memories)
)
@@ -701,7 +725,15 @@ def _failed_messages(replies):
retries += 1
n_mems = [min(1, len(obs['memories']) // 3) for obs in opening_obs]
for i, o in enumerate(opening_obs):
o.force_set('memories', random.sample(o['memories'], n_mems[i]))
mem_indices = random.sample(range(len(o['memories'])), n_mems[i])
o.force_set(
'memories',
{
m: v
for i, (m, v) in enumerate(o['memories'].items())
if i in mem_indices
},
)
if _failed_messages(batch_act):
for reply in batch_act:
text = reply.pop('text')
@@ -782,15 +814,8 @@ def batch_act(
for _ in range(len(observations))
]
# Step 1: determine whether we're searching or accessing memory
all_memory = [o['raw']['memories'] for o in observations]
available_memory = [
MemoryUtils.get_available_memories(
o['raw']['memories'],
o['raw']['in_session_memories'],
self.ignore_in_session_memories_mkm,
)
for o in observations
]
all_memory: List[Dict[str, int]] = [o['raw']['memories'] for o in observations]
available_memory = self.get_available_memories(observations)

batch_reply_sdm, search_indices = self.batch_act_decision(
observations,
@@ -908,18 +933,4 @@ def self_observe(self, self_message: Message):
agent.self_observe(self_message)
if self.vanilla:
return
memory_key = Module.MEMORY_GENERATOR.message_name()
for person in ['self', 'partner']:
memory_candidate = self_message.get(f"{memory_key}_{person}")
if not memory_candidate:
continue
if MemoryUtils.is_valid_memory(
self.memories,
memory_candidate,
MemoryUtils.get_memory_prefix(person, self.MODEL_TYPE),
):
memory_to_add = MemoryUtils.add_memory_prefix(
memory_candidate, person, self.MODEL_TYPE
)
self.memories.append(memory_to_add)
self.in_session_memories.add(memory_to_add)
self.self_observe_memory(self_message)
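The per-person memory bookkeeping that previously lived inline here is now delegated to `self.self_observe_memory`, whose body is not part of this hunk. A hedged reconstruction of what it likely does, reusing the calls visible in the removed lines plus the dict-style `MemoryUtils.add_memory` seen earlier; treat this as a sketch rather than the actual implementation:

```python
def self_observe_memory(self, self_message: Message) -> None:
    """Hypothetical sketch mirroring the inline logic removed above.

    For each speaker, take the generated memory candidate, keep it only if it
    is valid/new, prefix it with the speaker, then store it and mark it as an
    in-session memory. How the real helper stores dict-valued memories is an
    assumption here.
    """
    memory_key = Module.MEMORY_GENERATOR.message_name()
    for person in ['self', 'partner']:
        candidate = self_message.get(f"{memory_key}_{person}")
        if not candidate:
            continue
        if not MemoryUtils.is_valid_memory(
            self.memories,
            candidate,
            MemoryUtils.get_memory_prefix(person, self.MODEL_TYPE),
        ):
            continue
        memory = MemoryUtils.add_memory_prefix(candidate, person, self.MODEL_TYPE)
        self.memories = MemoryUtils.add_memory(memory, self.memories)
        self.in_session_memories.add(memory)
```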