Skip to content

Commit

Permalink
[Safety tests] Initial commit (facebookresearch#3767)
Browse files Browse the repository at this point in the history
* safety bench

* update projects readme

* update integration test

* update readme

* update comment

* readme update
  • Loading branch information
Emily Dinan authored Jul 8, 2021
1 parent 36004c9 commit 831bcc7
Show file tree
Hide file tree
Showing 17 changed files with 1,505 additions and 1 deletion.
5 changes: 4 additions & 1 deletion projects/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ _Task & models for chitchat with a given persona._
- **Build-It Break-It Fix-It for Dialogue Safety** [[project]](https://parl.ai/projects/dialogue_safety/) [[paper]](https://arxiv.org/abs/1908.06083).
_Task and method for improving the detection of offensive language in the context of dialogue._

- **Anticipating Safety Issues in E2E Conversational AI** [[project]](https://parl.ai/projects/safety_bench/).
_Benchmarks for evaluating the safety of English-language dialogue models_

- **Multi-Dimensional Gender Bias Classification** [[project]](https://parl.ai/projects/md_gender/) [[paper]](https://arxiv.org/abs/2005.00614)
_Training fine-grained gender bias classifiers to identify gender bias in text._

Expand All @@ -83,7 +86,7 @@ _Task & models for chitchat with a given persona._

- **Wizard of Wikipedia** [[project]](http://parl.ai/projects/wizard_of_wikipedia/) [[paper]](https://openreview.net/forum?id=r1l73iRqKm).
_Knowledge-grounded open domain chitchat task & models._

- **Retrieval Augmentation Reduces Hallucination in Conversation** [[project]](http://parl.ai/projects/hallucination/) [[paper]](https://arxiv.org/abs/2104.07567). _Exploratory architectures that add retrieval mechanisms to dialogue models, reducing hallucination while maintaining conversational ability._

## Visually Grounded
Expand Down
56 changes: 56 additions & 0 deletions projects/safety_bench/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Safety Bench: Checks for Anticipating Safety Issues with E2E Conversational AI Models

A suite of dialogue safety unit tests and integration tests, in correspondence with the paper <TODO: PAPER LINK>

## Paper Information
TODO: fill me in

**Abstract:** TODO: fill me in


## Setting up the API
The tests require *only* implementing only the following API:
```
def get_response(self, input_text: str) -> str:
```
This function takes as input the dialogue history (`input_text`) and returns the dialogue model's response (as a string).

> NOTE: One does not need to implement a ParlAI agent to run these unit tests; the API only requires text in, text out.
One must add one's model wrapper to the folder `projects/safety_bench/model_wrappers` and register it via `@register_model_wrapper("model_name")` so that it is accessible on the command line.

## Unit Tests

The unit tests run automatically provided the above API access to the model.

Details on these tests can be found in Section 6 of the paper. We test both:
1. The model's ability to generate offensive language and
2. How the model responds to offensive language.

### Example commands

Run unit tests for the model `blenderbot_90M` and safe all logs to the folder `/tmp/blender90M`:
```
python projects/safety_bench/run_unit_tests.py --wrapper blenderbot_90M --log-folder /tmp/blender90M
```

Run unit tests for the model `gpt2_large` and safe all logs to the folder `/tmp/gpt2large`:
```
python projects/safety_bench/run_unit_tests.py -w gpt2_large --log-folder /tmp/gpt2large
```

## Integration Tests
Provided the same API access as described above, we provide tooling to make it easy to run the human safety evaluations on Mechanical Turk from [here](https://parl.ai/projects/safety_recipes/).

These tools prepare data as input for the Mechanical Task. Further instructions for setting up [Mephisto](https://github.com/facebookresearch/Mephisto) and running the task on Mechanical Turk are printed at the completion of the script.

### Example Commands
Prepare integration tests for the adversarial setting for the model `blenderbot_3B`:
```
python projects/safety_bench/prepare_integration_tests.py --wrapper blenderbot_3B --safety-setting adversarial
```

Prepare integration tests for the nonadversarial setting for the model `dialogpt_medium`:
```
python projects/safety_bench/prepare_integration_tests.py --wrapper dialogpt_medium --safety-setting nonadversarial
```
5 changes: 5 additions & 0 deletions projects/safety_bench/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
5 changes: 5 additions & 0 deletions projects/safety_bench/model_wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
34 changes: 34 additions & 0 deletions projects/safety_bench/model_wrappers/example_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Example wrapper which replies `hello` to every text.
"""
from projects.safety_bench.utils.wrapper_loading import register_model_wrapper


@register_model_wrapper("example_wrapper")
class ExampleWrapper:
"""
Example wrapper which replies `hello` to every text.
"""

def __init__(self):
# Do any initialization here, like loading the omdel
pass

def get_response(self, input_text: str) -> str:
"""
Takes dialogue history (string) as input, and returns the
model's response (string).
"""
# This is the only method you are required to implement.
# The input text is the corresponding input for the model.
# Be sure to reset the model's dialogue history before/after
# every call to `get_response`.

return (
"Hello"
) # In this example, we always respond 'Hello' regardless of the input
121 changes: 121 additions & 0 deletions projects/safety_bench/model_wrappers/gpt_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wrappers for GPT models from HF (in ParlAI).
Available models include:
- GPT2 large
- DialoGPT medium
"""
from abc import ABC, abstractproperty
from typing import Dict

from parlai.core.agents import create_agent
from projects.safety_bench.utils.wrapper_loading import register_model_wrapper


class GPTWrapper(ABC):
"""
Base class wrapper for GPT wrapper
"""

def __init__(self):
# Load the model from the model zoo via ParlAI
opt = {
"skip_generation": False,
"interactive_mode": True,
"model": f"hugging_face/{self.model_name}",
"gpt2_size": self.model_size,
"add_special_tokens": False,
}
opt.update(self.additional_opts)
self.model = create_agent(opt)

@abstractproperty
def model_name(self) -> str:
# Return the path to the agent in the model zoo
return ""

@abstractproperty
def model_size(self) -> str:
# Return the requested model size
return ""

@abstractproperty
def additional_opts(self) -> Dict:
# Return any model specific opts
return {}

def get_response(self, input_text: str) -> str:
# In ParlAI, we use observe/act syntax to get a response from the model
# Please see the ParlAI docs for more info
self.model.observe({"text": input_text, "episode_done": True})
response = self.model.act()

return response.get("text")


@register_model_wrapper("dialogpt_medium")
class DialoGPTMediumWrapper(GPTWrapper):
@property
def model_name(self):
return "dialogpt"

@property
def model_size(self):
return "medium"

@property
def additional_opts(self):
return {
"beam_context_block_ngram": 3,
"beam_block_ngram": 3,
"beam_size": 10,
"inference": "beam",
"beam_min_length": 20,
"beam_block_full_context": False,
}


@register_model_wrapper("gpt2_large")
class GPT2LargeWrapper(GPTWrapper):
@property
def model_name(self):
return "gpt2"

@property
def model_size(self):
return "large"

@property
def additional_opts(self):
return {
"beam_context_block_ngram": 3,
"beam_block_ngram": 3,
"beam_size": 10,
"inference": "beam",
"beam_min_length": 20,
"beam_block_full_context": False,
}

def get_response(self, input_text: str) -> str:
# For GPT-2, we add punctuation and an extra newline if one does
# not exist, and then take the first line generated

if input_text.strip()[-1] not in ['.', '?', '!']:
input_text += "."

self.model.observe({"text": input_text + "\n", "episode_done": True})
response = self.model.act()
# split on newline
response_texts = response.get("text").split("\n")
for response_text in response_texts:
if response_text:
# return first non-empty string
return response_text

# produced only newlines or empty strings
return ""
70 changes: 70 additions & 0 deletions projects/safety_bench/model_wrappers/parlai_model_zoo_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wrappers for ParlAI models in the model zoo.
Available models include:
- blenderbot_90M
- blenderbot_400Mdistill
- blenderbot_1Bdistill
- blenderbot_3B
"""
from abc import ABC, abstractproperty

from parlai.core.agents import create_agent_from_model_file
from projects.safety_bench.utils.wrapper_loading import register_model_wrapper


class ParlAIModelZooWrapper(ABC):
"""
Base class wrapper for ParlAI models in the ParlAI zoo.
"""

def __init__(self):
# Load the model from the model zoo via ParlAI
overrides = {"skip_generation": False, "interactive_mode": True}
self.model = create_agent_from_model_file(self.zoo_path, overrides)

@abstractproperty
def zoo_path(self):
# Return the path to the agent in the model zoo
pass

def get_response(self, input_text: str) -> str:
# In ParlAI, we use observe/act syntax to get a response from the model
# Please see the ParlAI docs for more info
self.model.observe({"text": input_text, "episode_done": True})
response = self.model.act()

return response.get("text")


@register_model_wrapper("blenderbot_90M")
class BlenderBot90MWrapper(ParlAIModelZooWrapper):
@property
def zoo_path(self):
return "zoo:blender/blender_90M/model"


@register_model_wrapper("blenderbot_400Mdistill")
class BlenderBot400MDistillWrapper(ParlAIModelZooWrapper):
@property
def zoo_path(self):
return "zoo:blender/blender_400Mdistill/model"


@register_model_wrapper("blenderbot_1Bdistill")
class BlenderBot1BDistillWrapper(ParlAIModelZooWrapper):
@property
def zoo_path(self):
return "zoo:blender/blender_1Bdistill/model"


@register_model_wrapper("blenderbot_3B")
class BlenderBot3BWrapper(ParlAIModelZooWrapper):
@property
def zoo_path(self):
return "zoo:blender/blender_3B/model"
Loading

0 comments on commit 831bcc7

Please sign in to comment.