Skip to content

Commit

Permalink
Black ALL THE THINGS (facebookresearch#1802)
Browse files Browse the repository at this point in the history
* Add support for black.

* Trivial change to see if it's blacked.

* Also add a CI test.

* Don't force CI to install both.

* Updates.

* Slightly better output.

* And black this.

* Don't black parlai_internal

* Black all the things.

* Delete trailing whitespaces.

* Fix some mixed tabs/spaces.
  • Loading branch information
stephenroller authored Jun 20, 2019
1 parent 625b3db commit b994cec
Show file tree
Hide file tree
Showing 439 changed files with 24,158 additions and 16,071 deletions.
6 changes: 2 additions & 4 deletions .circleci/triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ def detect_gpu():
"""Check if we should run GPU tests."""
commit_msg = '[gpu]' in testing_utils.git_commit_messages()
test_changed = any(
'tests/nightly/gpu' in fn
for fn in testing_utils.git_changed_files()
'tests/nightly/gpu' in fn for fn in testing_utils.git_changed_files()
)
return commit_msg or test_changed

Expand All @@ -48,8 +47,7 @@ def detect_mturk():
"""Check if we should run mturk tests."""
commit_msg = '[mturk]' in testing_utils.git_commit_messages().lower()
mturk_changed = any(
'parlai/mturk' in fn
for fn in testing_utils.git_changed_files()
'parlai/mturk' in fn for fn in testing_utils.git_changed_files()
)
return commit_msg or mturk_changed

Expand Down
30 changes: 12 additions & 18 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@
import sphinx_rtd_theme


extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.githubpages'
]
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.githubpages']

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
Expand Down Expand Up @@ -128,15 +125,12 @@
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',

# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',

# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',

# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
Expand All @@ -145,20 +139,14 @@
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'ParlAI.tex', 'ParlAI Documentation',
'FAIR', 'manual'),
]
latex_documents = [(master_doc, 'ParlAI.tex', 'ParlAI Documentation', 'FAIR', 'manual')]


# -- Options for manual page output ---------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'parlai', 'ParlAI Documentation',
[author], 1)
]
man_pages = [(master_doc, 'parlai', 'ParlAI Documentation', [author], 1)]


# -- Options for Texinfo output -------------------------------------------
Expand All @@ -167,7 +155,13 @@
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'ParlAI', 'ParlAI Documentation',
author, 'ParlAI', 'One line description of project.',
'Miscellaneous'),
(
master_doc,
'ParlAI',
'ParlAI Documentation',
author,
'ParlAI',
'One line description of project.',
'Miscellaneous',
)
]
2 changes: 1 addition & 1 deletion docs/source/generate_task_READMEs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
display_name = task_dict.get('display_name', None)
task_detailed = task_dict.get('task', None)
if ':' in task_detailed:
task = task_detailed[0:task_detailed.find(':')]
task = task_detailed[0 : task_detailed.find(':')]
else:
task = task_detailed
tags = task_dict.get('tags', None)
Expand Down
5 changes: 2 additions & 3 deletions docs/source/generate_zoo_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
if 'example' in model:
example = model['example']
else:
example = (
"python -m parlai.scripts.eval_model --model {} --task {} -mf {}"
.format(model['agent'], model['task'], model['path'])
example = "python -m parlai.scripts.eval_model --model {} --task {} -mf {}".format(
model['agent'], model['task'], model['path']
)
result = model.get('result', '').strip().split("\n")
# strip leading whitespace from results
Expand Down
71 changes: 44 additions & 27 deletions parlai/agents/bert_classifier/bert_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from parlai.agents.bert_ranker.helpers import (
BertWrapper,
get_bert_optimizer,
MODEL_PATH
MODEL_PATH,
)
from parlai.core.utils import load_opt_file
from parlai.core.torch_agent import History
Expand All @@ -18,11 +18,16 @@
from collections import deque
import os
import torch

try:
from pytorch_pretrained_bert import BertModel
except ImportError:
raise Exception(("BERT rankers needs pytorch-pretrained-BERT installed. \n "
"pip install pytorch-pretrained-bert"))
raise Exception(
(
"BERT rankers needs pytorch-pretrained-BERT installed. \n "
"pip install pytorch-pretrained-bert"
)
)


class BertClassifierHistory(History):
Expand All @@ -49,11 +54,13 @@ class BertClassifierAgent(TorchClassifierAgent):
"""
Classifier based on Hugging Face BERT implementation.
"""

def __init__(self, opt, shared=None):
# download pretrained models
download(opt['datapath'])
self.pretrained_path = os.path.join(opt['datapath'], 'models',
'bert_models', MODEL_PATH)
self.pretrained_path = os.path.join(
opt['datapath'], 'models', 'bert_models', MODEL_PATH
)
opt['pretrained_path'] = self.pretrained_path
self._upgrade_opt(opt)
self.add_cls_token = opt.get('add_cls_token', True)
Expand All @@ -68,21 +75,34 @@ def history_class(cls):
def add_cmdline_args(parser):
TorchClassifierAgent.add_cmdline_args(parser)
parser = parser.add_argument_group('BERT Classifier Arguments')
parser.add_argument('--type-optimization', type=str,
default='all_encoder_layers',
choices=['additional_layers', 'top_layer',
'top4_layers', 'all_encoder_layers',
'all'],
help='which part of the encoders do we optimize '
'(defaults to all layers)')
parser.add_argument('--add-cls-token', type='bool', default=True,
help='add [CLS] token to text vec')
parser.add_argument('--sep-last-utt', type='bool', default=False,
help='separate the last utterance into a different'
'segment with [SEP] token in between')
parser.set_defaults(
dict_maxexs=0, # skip building dictionary
parser.add_argument(
'--type-optimization',
type=str,
default='all_encoder_layers',
choices=[
'additional_layers',
'top_layer',
'top4_layers',
'all_encoder_layers',
'all',
],
help='which part of the encoders do we optimize '
'(defaults to all layers)',
)
parser.add_argument(
'--add-cls-token',
type='bool',
default=True,
help='add [CLS] token to text vec',
)
parser.add_argument(
'--sep-last-utt',
type='bool',
default=False,
help='separate the last utterance into a different'
'segment with [SEP] token in between',
)
parser.set_defaults(dict_maxexs=0) # skip building dictionary

@staticmethod
def dictionary_class():
Expand All @@ -95,23 +115,20 @@ def _upgrade_opt(self, opt):
old_opt = load_opt_file(model_opt)
if 'add_cls_token' not in old_opt:
# old model, make this default to False
warn_once(
'Old model: overriding `add_cls_token` to False.'
)
warn_once('Old model: overriding `add_cls_token` to False.')
opt['add_cls_token'] = False
return

def build_model(self):
num_classes = len(self.class_list)
self.model = BertWrapper(
BertModel.from_pretrained(self.pretrained_path),
num_classes
BertModel.from_pretrained(self.pretrained_path), num_classes
)

def init_optim(self, params, optim_states=None, saved_optim_type=None):
self.optimizer = get_bert_optimizer([self.model],
self.opt['type_optimization'],
self.opt['learningrate'])
self.optimizer = get_bert_optimizer(
[self.model], self.opt['type_optimization'], self.opt['learningrate']
)

def _set_text_vec(self, *args, **kwargs):
obs = super()._set_text_vec(*args, **kwargs)
Expand Down
17 changes: 11 additions & 6 deletions parlai/agents/bert_ranker/bert_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
# LICENSE file in the root directory of this source tree.
from parlai.core.dict import DictionaryAgent
from parlai.zoo.bert.build import download

try:
from pytorch_pretrained_bert import BertTokenizer
except ImportError:
raise ImportError('BERT rankers needs pytorch-pretrained-BERT installed. \n '
'pip install pytorch-pretrained-bert')
raise ImportError(
'BERT rankers needs pytorch-pretrained-BERT installed. \n '
'pip install pytorch-pretrained-bert'
)

from .helpers import VOCAB_PATH

Expand All @@ -19,21 +22,23 @@
class BertDictionaryAgent(DictionaryAgent):
"""Allow to use the Torch Agent with the wordpiece dictionary of Hugging Face.
"""

def __init__(self, opt):
super().__init__(opt)
# initialize from vocab path
download(opt['datapath'])
vocab_path = os.path.join(opt['datapath'], 'models', 'bert_models',
VOCAB_PATH)
vocab_path = os.path.join(opt['datapath'], 'models', 'bert_models', VOCAB_PATH)
self.tokenizer = BertTokenizer.from_pretrained(vocab_path)

self.start_token = '[CLS]'
self.end_token = '[SEP]'
self.null_token = '[PAD]'
self.start_idx = self.tokenizer.convert_tokens_to_ids(['[CLS]'])[
0] # should be 101
0
] # should be 101
self.end_idx = self.tokenizer.convert_tokens_to_ids(['[SEP]'])[
0] # should be 102
0
] # should be 102
self.pad_idx = self.tokenizer.convert_tokens_to_ids(['[PAD]'])[0] # should be 0
# set tok2ind for special tokens
self.tok2ind[self.start_token] = self.start_idx
Expand Down
11 changes: 7 additions & 4 deletions parlai/agents/bert_ranker/bert_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@

class BertRankerAgent(TorchAgent):
"""Abstract parent class for all Bert Ranker agents."""

def __init__(self, opt, shared=None):
raise RuntimeError('You must specify which ranker to use. Choices: \n'
'-m bert_ranker/bi_encoder_ranker \n'
'-m bert_ranker/cross_encoder_ranker \n'
'-m bert_ranker/both_encoder_ranker')
raise RuntimeError(
'You must specify which ranker to use. Choices: \n'
'-m bert_ranker/bi_encoder_ranker \n'
'-m bert_ranker/cross_encoder_ranker \n'
'-m bert_ranker/both_encoder_ranker'
)
Loading

0 comments on commit b994cec

Please sign in to comment.