Skip to content

Commit

Permalink
Added data filter to the BotAdversarialDialogueTeacher to remove repe…
Browse files Browse the repository at this point in the history
…ating "Hey do you want to talk about something else" episodes. (facebookresearch#4732)

* Added data filter to the BotAdversarialDialogueTeacher to remove repeating "Hey do you want to talk about something else" episodes.

* Updated flag description of BADTeacher.

* Turned the "do you want to talk about something else" filter for the BADTeacher into a mutator

Co-authored-by: Leonard Adolphs <[email protected]>
  • Loading branch information
leox1v and Leonard Adolphs authored Aug 15, 2022
1 parent 5637374 commit 6f10d0c
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion parlai/tasks/bot_adversarial_dialogue/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
from typing import Optional, List
from parlai.core.params import ParlaiParser
import os

from parlai.core.opt import Opt
from parlai.core.teachers import ParlAIDialogTeacher
from parlai.core.mutators import (
register_mutator,
ManyEpisodeMutator,
)
from parlai.tasks.bot_adversarial_dialogue.build import (
build_dialogue_datasets,
build_human_safety_eval_dataset,
Expand Down Expand Up @@ -280,3 +284,24 @@ def __init__(self, opt, shared=None):

class DefaultTeacher(BotAdversarialDialogueTeacher):
pass


@register_mutator('filter_want_to_talk_about_labels')
class FilterWantToTalkAboutLabelsMutator(ManyEpisodeMutator):
"""
Mutator that filters out episodes that end in an utterance asking 'do you want to
talk about ...'.
This accounts for roughly 7k episodes.
"""

def _filter_fn(self, message: Message) -> bool:
utterances = message['text'].split('\n')
return 'do you want to talk about' not in utterances[-1].lower()

def many_episode_mutation(self, episode: List[Message]) -> List[List[Message]]:
new_episodes = []
for message in episode:
if self._filter_fn(message):
new_episodes.append(message)
return [new_episodes]

0 comments on commit 6f10d0c

Please sign in to comment.