Skip to content

Commit

Permalink
Colors in terminal display (facebookresearch#2515)
Browse files Browse the repository at this point in the history
* prettify display_data

* lint

* updates

* lint

* lint

* add interactive to colorization

* lint

* flip verbose/brief

* tests

* fix tests

* lint & black

Co-authored-by: Kurt Shuster <[email protected]>
  • Loading branch information
jaseweston and klshuster authored Apr 1, 2020
1 parent a833704 commit 8076e00
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 25 deletions.
10 changes: 8 additions & 2 deletions parlai/agents/local_human/local_human.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from parlai.core.agents import Agent
from parlai.core.message import Message
from parlai.utils.misc import display_messages, load_cands
from parlai.utils.strings import colorize


class LocalHumanAgent(Agent):
Expand Down Expand Up @@ -40,7 +41,12 @@ def __init__(self, opt, shared=None):
self.episodeDone = False
self.finished = False
self.fixedCands_txt = load_cands(self.opt.get('local_human_candidates_file'))
print("Enter [DONE] if you want to end the episode, [EXIT] to quit.\n")
print(
colorize(
"Enter [DONE] if you want to end the episode, [EXIT] to quit.",
'highlight',
)
)

def epoch_done(self):
return self.finished
Expand All @@ -57,7 +63,7 @@ def observe(self, msg):
def act(self):
reply = Message()
reply['id'] = self.getID()
reply_text = input("Enter Your Message: ")
reply_text = input(colorize("Enter Your Message:", 'field') + ' ')
reply_text = reply_text.replace('\\n', '\n')
if self.opt.get('single_turn', False):
reply_text += '[DONE]'
Expand Down
1 change: 1 addition & 0 deletions parlai/core/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def display(self):
ignore_fields=self.opt.get('display_ignore_fields', ''),
prettify=self.opt.get('display_prettify', False),
max_len=self.opt.get('max_display_len', 1000),
verbose=self.opt.get('display_verbose', False),
)

def episode_done(self):
Expand Down
36 changes: 34 additions & 2 deletions parlai/scripts/display_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from parlai.core.params import ParlaiParser
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent
from parlai.core.worlds import create_task
from parlai.utils.strings import colorize

import random

Expand All @@ -32,22 +33,53 @@ def setup_args(parser=None):
parser.add_argument('-n', '-ne', '--num-examples', type=int, default=10)
parser.add_argument('-mdl', '--max-display-len', type=int, default=1000)
parser.add_argument('--display-ignore-fields', type=str, default='agent_reply')
parser.add_argument(
'-v',
'--display-verbose',
default=False,
action='store_true',
help='If false, simple converational view, does not show other message fields.',
)

parser.set_defaults(datatype='train:stream')
return parser


def simple_display(opt, world, turn):
if opt['batchsize'] > 1:
raise RuntimeError('Simple view only support batchsize=1')
act = world.get_acts()[0]
if turn == 0:
text = (
" - - - NEW EPISODE: " + act.get('id', "[no agent id]") + " - - - "
)
print(colorize(text, 'highlight'))
text = act.get('text', '[no text field]')
print(colorize(text, 'text'))
labels = act.get('labels', act.get('eval_labels', ['[no labels field]']))
labels = '|'.join(labels)
print(' ' + colorize(labels, 'labels'))


def display_data(opt):
# create repeat label agent and assign it to the specified task
agent = RepeatLabelAgent(opt)
world = create_task(opt, agent)

# Show some example dialogs.
turn = 0
for _ in range(opt['num_examples']):
world.parley()

# NOTE: If you want to look at the data from here rather than calling
# world.display() you could access world.acts[0] directly
print(world.display() + '\n~~')
# world.display() you could access world.acts[0] directly, see simple_display above.
if opt['display_verbose']:
print(world.display() + '\n~~')
else:
simple_display(opt, world, turn)
turn += 1
if world.get_acts()[0]['episode_done']:
turn = 0

if world.epoch_done():
print('EPOCH DONE')
Expand Down
5 changes: 2 additions & 3 deletions parlai/scripts/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,12 @@ def interactive(opt, print_parser=None):

# Create model and assign it to the specified task
agent = create_agent(opt, requireModelExists=True)
human_agent = LocalHumanAgent(opt)
world = create_task(opt, [human_agent, agent])

if print_parser:
# Show arguments after loading model
print_parser.opt = agent.opt
print_parser.print_args()
human_agent = LocalHumanAgent(opt)
world = create_task(opt, [human_agent, agent])

# Show some example dialogs:
while True:
Expand Down
9 changes: 5 additions & 4 deletions parlai/tasks/empathetic_dialogues/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ def __init__(self, opt, shared=None):
opt.get('train_experiencer_only', DEFAULT_TRAIN_EXPERIENCER_ONLY)
and base_datatype == 'train'
) or base_datatype != 'train'
print(
f'[EmpatheticDialoguesTeacher] Only use experiencer side? '
f'{self.experiencer_side_only}, datatype: {self.datatype}'
)
if not shared:
print(
f'[EmpatheticDialoguesTeacher] Only use experiencer side? '
f'{self.experiencer_side_only}, datatype: {self.datatype}'
)
self.remove_political_convos = opt.get(
'remove_political_convos', DEFAULT_REMOVE_POLITICAL_CONVOS
)
Expand Down
42 changes: 35 additions & 7 deletions parlai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json

from parlai.core.message import Message
from parlai.utils.strings import colorize

try:
import torch
Expand Down Expand Up @@ -710,7 +711,7 @@ def _ellipse(lst: List[str], max_display: int = 5, sep: str = '|') -> str:
choices = list(lst)
# insert the ellipsis if necessary
if max_display > 0 and len(choices) > max_display:
ellipsis = '...and {} more'.format(len(choices) - max_display)
ellipsis = '... ({} of {} shown)'.format(max_display, len(choices))
choices = choices[:max_display] + [ellipsis]
return sep.join(str(c) for c in choices)

Expand All @@ -720,6 +721,7 @@ def display_messages(
prettify: bool = False,
ignore_fields: str = '',
max_len: int = 1000,
verbose: bool = False,
) -> Optional[str]:
"""
Return a string describing the set of messages provided.
Expand Down Expand Up @@ -755,6 +757,10 @@ def _token_losses_line(
# We only display the first agent (typically the teacher) if we
# are ignoring the agent reply.
continue
agent_id = msg.get('id', '[no id field]')
if verbose:
lines.append(colorize('[id]:', 'field') + ' ' + colorize(agent_id, 'id'))

if msg.get('episode_done'):
episode_done = True
# Possibly indent the text (for the second speaker, if two).
Expand All @@ -766,27 +772,49 @@ def _token_losses_line(
lines.append(space + '[reward: {r}]'.format(r=msg['reward']))
for key in msg:
if key not in DISPLAY_MESSAGE_DEFAULT_FIELDS and key not in ignore_fields_:
field = colorize('[' + key + ']:', 'field')
if type(msg[key]) is list:
line = '[' + key + ']:\n ' + _ellipse(msg[key], sep='\n ')
value = _ellipse(msg[key], sep='\n ')
else:
line = '[' + key + ']: ' + clip_text(str(msg.get(key)), max_len)
value = clip_text(str(msg.get(key)), max_len)
line = field + ' ' + colorize(value, 'text2')
lines.append(space + line)
if type(msg.get('image')) in [str, torch.Tensor]:
lines.append(f'[ image ]: {msg["image"]}')
if msg.get('text', ''):
text = clip_text(msg['text'], max_len)
ID = '[' + msg['id'] + ']: ' if 'id' in msg else ''
lines.append(space + ID + text)
if index == 0:
style = 'bold_text'
else:
style = 'labels'
if verbose:
lines.append(
space + colorize('[text]:', 'field') + ' ' + colorize(text, style)
)
else:
lines.append(
space
+ colorize("[" + agent_id + "]:", 'field')
+ ' '
+ colorize(text, style)
)
for field in {'labels', 'eval_labels', 'label_candidates', 'text_candidates'}:
if msg.get(field) and field not in ignore_fields_:
lines.append('{}[{}: {}]'.format(space, field, _ellipse(msg[field])))
string = '{}{} {}'.format(
space,
colorize('[' + field + ']:', 'field'),
colorize(_ellipse(msg[field]), field),
)
lines.append(string)
# Handling this separately since we need to clean up the raw output before displaying.
token_loss_line = _token_losses_line(msg, ignore_fields_, space)
if token_loss_line:
lines.append(token_loss_line)

if episode_done:
lines.append('- - - - - - - - - - - - - - - - - - - - -')
lines.append(
colorize('- - - - - - - END OF EPISODE - - - - - - - - - -', 'highlight')
)

return '\n'.join(lines)

Expand Down
34 changes: 34 additions & 0 deletions parlai/utils/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
Utility functions and classes for handling text strings.
"""
import sys as _sys


def normalize_reply(text: str, version=1) -> str:
Expand Down Expand Up @@ -60,3 +61,36 @@ def uppercase(string: str) -> str:
return string
else:
return string[0].upper() + string[1:]


def colorize(text, style):
USE_COLORS = _sys.stdout.isatty()
BLUE = '\033[1;94m'
BOLD_LIGHT_GRAY = '\033[1;37;40m'
LIGHT_GRAY = '\033[0;37;40m'
MAGENTA = '\033[0;95m'
HIGHLIGHT_RED = '\033[1;37;41m'
HIGHLIGHT_BLUE = '\033[1;37;44m'
RESET = '\033[0;0m'
if not USE_COLORS:
return text
if style == 'highlight':
return HIGHLIGHT_RED + text + RESET
if style == 'highlight2':
return HIGHLIGHT_BLUE + text + RESET
elif style == 'text':
return LIGHT_GRAY + text + RESET
elif style == 'bold_text':
return BOLD_LIGHT_GRAY + text + RESET
elif style == 'labels' or style == 'eval_labels':
return BLUE + text + RESET
elif style == 'label_candidates':
return LIGHT_GRAY + text + RESET
elif style == 'id':
return LIGHT_GRAY + text + RESET
elif style == 'text2':
return MAGENTA + text + RESET
elif style == 'field':
return HIGHLIGHT_BLUE + text + RESET
else:
return MAGENTA + text + RESET
4 changes: 2 additions & 2 deletions tests/test_display_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ def test_output(self):
Does display_data reach the end of the loop?
"""
str_output, _, _ = testing_utils.display_data(
{'num_examples': 1, 'task': 'babi:task1k:1'}
{'num_examples': 1, 'task': 'babi:task1k:1', 'display_verbose': True}
)

self.assertGreater(len(str_output), 0, "Output is empty")
self.assertIn("[babi:task1k:1]:", str_output, "Babi task did not print")
self.assertIn("babi:task1k:1", str_output, "Babi task did not print")
self.assertIn("~~", str_output, "Example output did not complete")


Expand Down
11 changes: 6 additions & 5 deletions tests/test_teachers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ def _test_display_output(self, image_mode):
'task': 'integration_tests:ImageTeacher',
'datapath': data_path,
'image_mode': image_mode,
'display_verbose': True,
}
output = testing_utils.display_data(opt)
train_labels = re.findall(r"\[labels: .*\]", output[0])
valid_labels = re.findall(r"\[eval_labels: .*\]", output[1])
test_labels = re.findall(r"\[eval_labels: .*\]", output[2])
train_labels = re.findall(r"\[labels\].*\n", output[0])
valid_labels = re.findall(r"\[eval_labels\].*\n", output[1])
test_labels = re.findall(r"\[eval_labels\].*\n", output[2])

for i, lbls in enumerate([train_labels, valid_labels, test_labels]):
self.assertGreater(len(lbls), 0, 'DisplayData failed')
Expand Down Expand Up @@ -66,7 +67,7 @@ def test_good_fileformat(self):
fp = os.path.join(tmpdir, "goodfile.txt")
with open(fp, "w") as f:
f.write('id:test_file\ttext:input\tlabels:good label\n\n')
opt = {'task': 'fromfile', 'fromfile_datapath': fp}
opt = {'task': 'fromfile', 'fromfile_datapath': fp, 'display_verbose': True}
testing_utils.display_data(opt)

def test_bad_fileformat(self):
Expand All @@ -77,7 +78,7 @@ def test_bad_fileformat(self):
fp = os.path.join(tmpdir, "badfile.txt")
with open(fp, "w") as f:
f.write('id:test_file\ttext:input\teval_labels:bad label\n\n')
opt = {'task': 'fromfile', 'fromfile_datapath': fp}
opt = {'task': 'fromfile', 'fromfile_datapath': fp, 'display_verbose': True}
with self.assertRaises(ValueError):
testing_utils.display_data(opt)

Expand Down

0 comments on commit 8076e00

Please sign in to comment.