Skip to content

Commit

Permalink
Readd the script module changes (facebookresearch#3851)
Browse files Browse the repository at this point in the history
* Readd the script module changes

* Update parlai/torchscript/modules.py

Co-authored-by: Stephen Roller <[email protected]>

Co-authored-by: deankita <[email protected]>
Co-authored-by: Stephen Roller <[email protected]>
  • Loading branch information
3 people authored Jul 26, 2021
1 parent 5472513 commit 295082f
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 52 deletions.
61 changes: 37 additions & 24 deletions parlai/scripts/torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
from typing import List

import torch.jit
import torch.nn as nn
from packaging import version

from parlai.core.agents import create_agent
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
Expand All @@ -24,59 +24,72 @@ def export_model(opt: Opt):
Currently, only CPU greedy-search inference on BART models is supported.
"""

if version.parse(torch.__version__) < version.parse('1.7.0'):
if version.parse(torch.__version__) < version.parse("1.7.0"):
raise NotImplementedError(
'TorchScript export is only supported for Torch 1.7 and higher!'
"TorchScript export is only supported for Torch 1.7 and higher!"
)
else:
# Only load TorchScriptGreedySearch now, because this will trigger scripting of
# associated modules
from parlai.torchscript.modules import TorchScriptGreedySearch

overrides = {
'no_cuda': True, # TorchScripting is CPU only
'model_parallel': False, # model_parallel is not currently supported when TorchScripting
"no_cuda": True, # TorchScripting is CPU only
"model_parallel": False, # model_parallel is not currently supported when TorchScripting
}
if 'override' not in opt:
opt['override'] = {}
if opt.get("script_module"):
script_module_name, script_class_name = opt["script_module"].split(":", 1)
script_module = importlib.import_module(script_module_name)
script_class = getattr(script_module, script_class_name)
else:
script_class = TorchScriptGreedySearch
if "override" not in opt:
opt["override"] = {}
for k, v in overrides.items():
opt[k] = v
opt['override'][k] = v
opt["override"][k] = v

# Create the unscripted greedy-search module
agent = create_agent(opt, requireModelExists=True)
original_module = TorchScriptGreedySearch(agent)
original_module = script_class(agent)

# Script the module and save
scripted_module = torch.jit.script(TorchScriptGreedySearch(agent))
with PathManager.open(opt['scripted_model_file'], 'wb') as f:
scripted_module = torch.jit.script(script_class(agent))
with PathManager.open(opt["scripted_model_file"], "wb") as f:
torch.jit.save(scripted_module, f)

# Compare the original module to the scripted module against the test inputs
if len(opt['input']) > 0:
inputs = opt['input'].split('|')
print('\nGenerating given the original unscripted module:')
if len(opt["input"]) > 0:
inputs = opt["input"].split("|")
print("\nGenerating given the original unscripted module:")
_run_conversation(module=original_module, inputs=inputs)
print('\nGenerating given the scripted module:')
print("\nGenerating given the scripted module:")
_run_conversation(module=scripted_module, inputs=inputs)


def setup_args() -> ParlaiParser:
parser = ParlaiParser(add_parlai_args=True, add_model_args=True)
parser.add_argument(
'-smf',
'--scripted-model-file',
"-smf",
"--scripted-model-file",
type=str,
default='_scripted.pt',
help='Where the scripted model checkpoint will be saved',
default="_scripted.pt",
help="Where the scripted model checkpoint will be saved",
)
parser.add_argument(
"-in",
"--input",
type=str,
default='',
default="",
help="Input string to pass into the encoder of the scripted model, to test it against the unscripted version. Separate lines with a pipe",
)
parser.add_argument(
"-sm",
"--script-module",
type=str,
default="parlai.torchscript.modules:TorchScriptGreedySearch",
help="module to TorchScript. Example: parlai.torchscript.modules:TorchScriptGreedySearch",
)
return parser


Expand All @@ -86,14 +99,14 @@ def _run_conversation(module: nn.Module, inputs: List[str]):
"""
context = []
for input_ in inputs:
print(' TEXT: ' + input_)
print(" TEXT: " + input_)
context.append(input_)
label = module('\n'.join(context))
label = module("\n".join(context))
print("LABEL: " + label)
context.append(label)


@register_script('torchscript', hidden=True)
@register_script("torchscript", hidden=True)
class TorchScript(ParlaiScript):
@classmethod
def setup_args(cls):
Expand All @@ -103,5 +116,5 @@ def run(self):
return export_model(self.opt)


if __name__ == '__main__':
if __name__ == "__main__":
TorchScript.main()
66 changes: 38 additions & 28 deletions parlai/torchscript/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
from typing import List, Dict, Optional, Tuple

import torch.jit
from torch import nn as nn

from parlai.core.dict import DictionaryAgent
from parlai.core.torch_agent import TorchAgent
from parlai.utils.bpe import Gpt2BpeHelper
from torch import nn as nn


class TorchScriptGreedySearch(nn.Module):
Expand All @@ -34,13 +33,13 @@ class TorchScriptGreedySearch(nn.Module):
"dict_lower": False,
"dict_textfields": "text,labels",
"dict_loaded": True,
'bpe_debug': False,
"bpe_debug": False,
}

def __init__(self, agent: TorchAgent):
super().__init__()

self.is_bart = agent.opt['model'] == 'bart'
self.is_bart = agent.opt["model"] == "bart"

# Dictionary/tokenization setup
for key, val in self.CAIRAOKE_DICT_PARAMS.items():
Expand All @@ -51,10 +50,10 @@ def __init__(self, agent: TorchAgent):
orig_bpe: Gpt2BpeHelper = orig_dict.bpe
assert all(len(key) == 2 for key in orig_bpe.bpe_ranks.keys())
assert not any(
i for key in orig_bpe.bpe_ranks.keys() for i in key if '\n' in i
i for key in orig_bpe.bpe_ranks.keys() for i in key if "\n" in i
), "We need to temporarily merge the bpe_ranks dict's keys with a newline character in order to use it as a TorchScript arg, but at least one of the dict's keys contains a newline character already!"
fused_key_bpe_ranks = {
'\n'.join(key): float(val) for key, val in orig_bpe.bpe_ranks.items()
"\n".join(key): float(val) for key, val in orig_bpe.bpe_ranks.items()
}
# Cast the values as floats to be able to compare to float('inf') when doing BPE
# splitting
Expand All @@ -66,7 +65,7 @@ def __init__(self, agent: TorchAgent):
freq=orig_dict.freq,
tok2ind=orig_dict.tok2ind,
ind2tok=orig_dict.ind2tok,
bpe_add_prefix_space=agent.opt['bpe_add_prefix_space'],
bpe_add_prefix_space=agent.opt["bpe_add_prefix_space"],
bpe_encoder=orig_bpe.encoder,
bpe_byte_encoder=orig_bpe.byte_encoder,
fused_key_bpe_ranks=fused_key_bpe_ranks,
Expand All @@ -75,12 +74,12 @@ def __init__(self, agent: TorchAgent):

# History tracking and start/end tokens
self.delimiter_tok = agent.history.delimiter_tok
self.history_size = agent.opt['history_size']
if agent.opt.get('history_add_global_end_token', None) is not None:
self.history_size = agent.opt["history_size"]
if agent.opt.get("history_add_global_end_token", None) is not None:
self.global_end_token = agent.dict[agent.dict.end_token]
else:
self.global_end_token = None
self.text_truncate = agent.opt.get('text_truncate') or agent.opt['truncate']
self.text_truncate = agent.opt.get("text_truncate") or agent.opt["truncate"]
self.text_truncate = self.text_truncate if self.text_truncate >= 0 else None

self.start_idx = agent.model.START_IDX
Expand Down Expand Up @@ -126,8 +125,8 @@ def __init__(self, agent: TorchAgent):
self.partially_traced_model = torch.jit.trace_module(
wrapped_model,
{
'output': (latent[:, -1:, :]),
'reorder_decoder_incremental_state': (
"output": (latent[:, -1:, :]),
"reorder_decoder_incremental_state": (
initial_incr_state,
torch.tensor([0], dtype=torch.long, device=sample_tokens.device),
),
Expand Down Expand Up @@ -172,7 +171,8 @@ def forward(self, context: str, max_len: int = 128) -> str:

# Vectorize all lines of context
history_vecs: List[List[int]] = []
context_lines = context.split('\n')
context_lines = context.split("\n")
context_lines = self.preprocess_context(context_lines)
if self.history_size > 0:
context_lines = context_lines[-self.history_size :]
for line in context_lines:
Expand Down Expand Up @@ -250,8 +250,18 @@ def forward(self, context: str, max_len: int = 128) -> str:
generation_tokens: List[int] = generations[0].tolist()
label = self._v2t(generation_tokens)

return self.postprocess_output_generations(label=label)

def postprocess_output_generations(self, label: str) -> str:
    """
    Hook for post-processing the generated text.

    The default implementation is a pass-through that hands ``label`` back
    unchanged. Override it in a subclass to add custom logic.

    :param label:
        the generated text.

    :return:
        the text after post-processing (identical to ``label`` by default).
    """
    return label

def preprocess_context(self, context_lines: List[str]) -> List[str]:
    """
    Hook for pre-processing the lines of context.

    The default implementation is a pass-through that returns
    ``context_lines`` unchanged. Override it in a subclass to add custom
    logic.

    :param context_lines:
        the context, one string per line.

    :return:
        the lines of context to use (the same list by default).
    """
    return context_lines


class BaseIncrStateFlattener(nn.Module):
"""
Expand Down Expand Up @@ -286,7 +296,7 @@ def _unflatten_incr_state(
"""
structured_incr_state = defaultdict(lambda: defaultdict(dict))
for key, state in flat_incr_state.items():
layer_idx_str, attn_type, state_type = key.split('__')
layer_idx_str, attn_type, state_type = key.split("__")
structured_incr_state[int(layer_idx_str)][attn_type][state_type] = state
return dict({k: dict(v) for k, v in structured_incr_state.items()})
# Turn the nested defaultdicts back into regular dicts
Expand All @@ -304,7 +314,7 @@ def _flatten_incr_state(
for layer_idx, dict1 in structured_incr_state.items():
for attn_type, dict2 in dict1.items():
for state_type, state in dict2.items():
key = f'{layer_idx:d}__{attn_type}__{state_type}'
key = f"{layer_idx:d}__{attn_type}__{state_type}"
flat_incr_state[key] = state
return flat_incr_state

Expand Down Expand Up @@ -367,15 +377,15 @@ def findall(cls, text: str) -> List[str]:
"""
Split tokens in a manner that replicates parlai.utils.bpe.Gpt2BpeHelper.
"""
contraction_endings = ['s', 't', 're', 've', 'm', 'll', 'd']
contraction_endings = ["s", "t", "re", "ve", "m", "ll", "d"]

tokens: List[str] = []
idx = 0
num_passes = 0
while idx < len(text):
num_passes += 1
if num_passes > 10000:
return ['*** Infinite loop in ScriptableGpt2BpeHelper.findall()! ***']
return ["*** Infinite loop in ScriptableGpt2BpeHelper.findall()! ***"]
if text[idx] == "'":
# Capture contraction suffixes
captured_suffix = False
Expand All @@ -388,10 +398,10 @@ def findall(cls, text: str) -> List[str]:
if captured_suffix:
continue
if not text[idx].isspace() or (
text[idx] == ' ' and idx + 1 < len(text) and not text[idx + 1].isspace()
text[idx] == " " and idx + 1 < len(text) and not text[idx + 1].isspace()
):
# Capture runs of one type of character
if text[idx] == ' ':
if text[idx] == " ":
last_matching_idx = idx + 1
else:
last_matching_idx = idx
Expand Down Expand Up @@ -487,7 +497,7 @@ def encode(self, text: str) -> List[str]:
A list of tokens
"""
if self.add_prefix_space:
text = f' {text}'
text = f" {text}"

# constants for readability
FINAL = 1
Expand Down Expand Up @@ -520,7 +530,7 @@ def encode(self, text: str) -> List[str]:
output.append(piece)
else:
output += self.helper_encode(piece)
text = ''.join(output)
text = "".join(output)

return output

Expand Down Expand Up @@ -559,14 +569,14 @@ def bpe(self, word: List[str]) -> List[str]:
return word

while True:
min_rank = self.bpe_ranks.get('\n'.join(pairs[0]), float('inf'))
min_rank = self.bpe_ranks.get("\n".join(pairs[0]), float("inf"))
bigram = pairs[0]
for pair in pairs[1:]:
current_rank = self.bpe_ranks.get('\n'.join(pair), float('inf'))
current_rank = self.bpe_ranks.get("\n".join(pair), float("inf"))
if current_rank < min_rank:
min_rank = current_rank
bigram = pair
if '\n'.join(bigram) not in self.bpe_ranks:
if "\n".join(bigram) not in self.bpe_ranks:
break
first, second = bigram
new_word: List[str] = []
Expand Down Expand Up @@ -640,10 +650,10 @@ def decode(self, tokens: List[str]) -> str:
if len(accum) > 0:
output.append(self.helper_decode(accum))

text = ''.join(output)
text = "".join(output)
if self.add_prefix_space:
assert text.startswith(' ')
text = text.lstrip(' ')
assert text.startswith(" ")
text = text.lstrip(" ")
return text

def helper_decode(self, tokens: List[str]) -> str:
Expand Down Expand Up @@ -672,7 +682,7 @@ def helper_decode(self, tokens: List[str]) -> str:
decoded_chars: List[str] = []
for char in chars:
decoded_chars.append(chr(self.byte_decoder[char]))
return ''.join(decoded_chars)
return "".join(decoded_chars)

def utf8_chars(self, s: str) -> List[str]:
"""
Expand Down

0 comments on commit 295082f

Please sign in to comment.