Skip to content

Commit

Permalink
Add support for PathManager. (facebookresearch#3011)
Browse files (browse the repository at this point in the history)
* Add support for PathManager.

* Try fixing Twitter.

* Fix.

* Lint.

* Pickle needs binary mode.

* Re-black.

* Fix lines missed by the Twitter fixes.
  • Loading branch information
stephenroller authored Aug 28, 2020
1 parent 00efcbe commit 8200396
Show file tree
Hide file tree
Showing 164 changed files with 938 additions and 663 deletions.
6 changes: 4 additions & 2 deletions parlai/agents/bart/convert_fairseq_to_parlai.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.script import ParlaiScript
from parlai.utils.io import PathManager


TRANSFORMER_PARAMETER_MAPPING = {
'attention_heads': 'n_heads',
Expand Down Expand Up @@ -240,7 +242,7 @@ def _load_single_fairseq_checkpoint(self, path: str) -> Dict[str, Any]:
:return state:
loaded fairseq state
"""
with open(path, "rb") as f:
with PathManager.open(path, "rb") as f:
try:
state = torch.load(
f, map_location=lambda s, l: default_restore_location(s, "cpu")
Expand Down Expand Up @@ -397,7 +399,7 @@ def convert_model_weight(self, opt: Opt) -> Dict[str, Any]:
# 6. Shuffle embedding matrix given dictionary.
enc_emb_key = 'encoder.embeddings.weight'
bart_dict = os.path.join(opt['datapath'], 'models/bart/bart.large/dict.txt')
with open(bart_dict) as f:
with PathManager.open(bart_dict) as f:
offset_dict = {i: l.split()[0] for i, l in enumerate(f.readlines())}
new_embs = return_dict[enc_emb_key].clone()
for idx, new_idx in offset_dict.items():
Expand Down
3 changes: 2 additions & 1 deletion parlai/agents/bert_ranker/bi_encoder_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from parlai.core.torch_ranker_agent import TorchRankerAgent
from parlai.utils.torch import padded_3d
from parlai.zoo.bert.build import download
from parlai.utils.io import PathManager

from .bert_dictionary import BertDictionaryAgent
from .helpers import (
Expand Down Expand Up @@ -101,7 +102,7 @@ def set_vocab_candidates(self, shared):
"".format(len(self.vocab_candidates))
)
enc_path = self.opt.get('model_file') + '.vocab.encs'
if os.path.isfile(enc_path):
if PathManager.exists(enc_path):
self.vocab_candidate_encs = self.load_candidates(
enc_path, cand_type='vocab encodings'
)
Expand Down
10 changes: 5 additions & 5 deletions parlai/agents/drqa/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from parlai.core.build_data import modelzoo_path
from parlai.utils.io import PathManager


def add_cmdline_args(parser):
Expand Down Expand Up @@ -161,10 +161,10 @@ def add_cmdline_args(parser):
def set_defaults(opt):
init_model = None
# check first for 'init_model' for loading model from file
if opt.get('init_model') and os.path.isfile(opt['init_model']):
if opt.get('init_model') and PathManager.exists(opt['init_model']):
init_model = opt['init_model']
# next check for 'model_file', this would override init_model
if opt.get('model_file') and os.path.isfile(opt['model_file']):
if opt.get('model_file') and PathManager.exists(opt['model_file']):
init_model = opt['model_file']

if init_model is None:
Expand All @@ -173,9 +173,9 @@ def set_defaults(opt):
opt.get('datapath'), opt['embedding_file']
)
if opt.get('embedding_file'):
if not os.path.isfile(opt['embedding_file']):
if not PathManager.exists(opt['embedding_file']):
raise IOError('No such file: %s' % opt['embedding_file'])
with open(opt['embedding_file']) as f:
with PathManager.open(opt['embedding_file']) as f:
dim = len(f.readline().strip().split(' ')) - 1
if dim == 1:
# first line was a dud
Expand Down
8 changes: 4 additions & 4 deletions parlai/agents/drqa/drqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
raise ImportError('Need to install pytorch: go to pytorch.org')

import bisect
import os
import numpy as np
import json
import random

from parlai.core.agents import Agent
from parlai.core.dict import DictionaryAgent
from parlai.core.build_data import modelzoo_path
from parlai.utils.io import PathManager
from . import config
from .utils import build_feature_dict, vectorize, batchify, normalize_text
from .model import DocReaderModel
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self, *args, **kwargs):
self.opt['embedding_file'] = modelzoo_path(
self.opt.get('datapath'), self.opt['embedding_file']
)
with open(self.opt['embedding_file']) as f:
with PathManager.open(self.opt['embedding_file']) as f:
for line in f:
w = normalize_text(line.rstrip().split(' ')[0])
self.embedding_words.add(w)
Expand Down Expand Up @@ -128,7 +128,7 @@ def __init__(self, opt, shared=None):
else:
# set up model
self.word_dict = DrqaAgent.dictionary_class()(opt)
if self.opt.get('model_file') and os.path.isfile(opt['model_file']):
if self.opt.get('model_file') and PathManager.exists(opt['model_file']):
self._init_from_saved(opt['model_file'])
else:
if self.opt.get('init_model'):
Expand Down Expand Up @@ -274,7 +274,7 @@ def save(self, fname=None):
self.opt['trained'] = True
self.model.save(fname)
# save opt file
with open(fname + '.opt', 'w') as handle:
with PathManager.open(fname + '.opt', 'w') as handle:
json.dump(self.opt, handle)

# --------------------------------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion parlai/agents/drqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import unicodedata
from collections import Counter

from parlai.utils.io import PathManager
from parlai.core.build_data import modelzoo_path


Expand All @@ -29,7 +30,7 @@ def load_embeddings(opt, word_dict):
# Fill in embeddings
if not opt.get('embedding_file'):
raise RuntimeError('Tried to load embeddings with no embedding file.')
with open(opt['embedding_file']) as f:
with PathManager.open(opt['embedding_file']) as f:
for line in f:
parsed = line.rstrip().split(' ')
if len(parsed) > 2:
Expand Down
5 changes: 3 additions & 2 deletions parlai/agents/ir_baseline/ir_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from parlai.core.agents import Agent
from parlai.core.dict import DictionaryAgent
from parlai.utils.io import PathManager


class MaxPriorityQueue(Sequence):
Expand Down Expand Up @@ -330,9 +331,9 @@ def save(self, path=None):
self.dictionary.save(path + '.dict')
data = {}
data['opt'] = self.opt
with open(path, 'wb') as handle:
with PathManager.open(path, 'wb') as handle:
torch.save(data, handle)
with open(path + '.opt', 'w') as handle:
with PathManager.open(path + '.opt', 'w') as handle:
json.dump(self.opt, handle)

def load(self, fname):
Expand Down
12 changes: 6 additions & 6 deletions parlai/agents/starspace/starspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from parlai.core.dict import DictionaryAgent
from parlai.utils.misc import maintain_dialog_history, load_cands
from parlai.core.torch_agent import TorchAgent
from parlai.utils.io import PathManager
from .modules import Starspace

import torch
Expand All @@ -20,7 +21,6 @@
from collections import deque

import copy
import os
import random
import json

Expand Down Expand Up @@ -198,7 +198,7 @@ def __init__(self, opt, shared=None):
print("[ creating StarspaceAgent ]")
# this is not a shared instance of this class, so do full init
if opt.get('model_file') and (
os.path.isfile(opt.get('model_file') + '.dict')
PathManager.exists(opt.get('model_file') + '.dict')
or (opt['dict_file'] is None)
):
# set default dict-file if not set
Expand All @@ -207,7 +207,7 @@ def __init__(self, opt, shared=None):
self.dict = DictionaryAgent(opt)

self.model = Starspace(opt, len(self.dict), self.dict)
if opt.get('model_file') and os.path.isfile(opt['model_file']):
if opt.get('model_file') and PathManager.exists(opt['model_file']):
self.load(opt['model_file'])
else:
self._init_embeddings()
Expand Down Expand Up @@ -434,7 +434,7 @@ def predict(self, xs, ys=None, cands=None, cands_txt=None, obs=None):
for c in negs:
print("neg: " + self.v2t(c.squeeze()))
print("---")
y = -torch.ones(xe.size(0))
y = -(torch.ones(xe.size(0)))
y[0] = 1
loss = self.criterion(xe, ye, y)
loss.backward()
Expand Down Expand Up @@ -585,9 +585,9 @@ def save(self, path=None):
data['model'] = self.model.state_dict()
data['optimizer'] = self.optimizer.state_dict()
data['opt'] = self.opt
with open(path, 'wb') as handle:
with PathManager.open(path, 'wb') as handle:
torch.save(data, handle)
with open(path + '.opt', 'w') as handle:
with PathManager.open(path + '.opt', 'w') as handle:
json.dump(self.opt, handle)

def load(self, path):
Expand Down
4 changes: 2 additions & 2 deletions parlai/agents/tfidf_retriever/build_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
"""

import sqlite3
import os

from tqdm import tqdm

from collections import deque
import random
from parlai.core.teachers import create_task_agent_from_taskname
import parlai.utils.logging as logging
from parlai.utils.io import PathManager

# ------------------------------------------------------------------------------
# Store corpus.
Expand All @@ -33,7 +33,7 @@ def store_contents(opt, task, save_path, context_length=-1, include_labels=True)
save_path: Path to output sqlite db.
num_workers: Number of parallel processes to use when reading docs.
"""
if os.path.isfile(save_path):
if PathManager.exists(save_path):
raise RuntimeError('%s already exists! Not overwriting.' % save_path)

logging.info('Reading into database...')
Expand Down
5 changes: 3 additions & 2 deletions parlai/agents/tfidf_retriever/tfidf_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)

from parlai.core.agents import Agent
from parlai.utils.io import PathManager
from parlai.utils.misc import AttrDict
from .doc_db import DocDB
from .tfidf_doc_ranker import TfidfDocRanker
Expand Down Expand Up @@ -218,9 +219,9 @@ def rebuild(self):

def save(self, path=None):
self.rebuild()
with open(self.opt['model_file'] + '.opt', 'w') as handle:
with PathManager.open(self.opt['model_file'] + '.opt', 'w') as handle:
json.dump(self.opt, handle)
with open(self.opt['model_file'], 'w') as f:
with PathManager.open(self.opt['model_file'], 'w') as f:
f.write('\n')

def train_act(self):
Expand Down
5 changes: 3 additions & 2 deletions parlai/agents/unigram/unigram.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from parlai.core.agents import Agent
from parlai.core.dict import DictionaryAgent
from itertools import islice
from parlai.utils.io import PathManager


class UnigramAgent(Agent):
Expand Down Expand Up @@ -109,10 +110,10 @@ def save(self, path=None):
if not path:
return

with open(path, 'w') as f:
with PathManager.open(path, 'w') as f:
f.write(self.get_prediction() + '\n')

with open(path + '.opt', 'w') as f:
with PathManager.open(path + '.opt', 'w') as f:
json.dump(self.opt, f)

def load(self, path):
Expand Down
3 changes: 2 additions & 1 deletion parlai/chat_service/services/messenger/messenger_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from parlai.core.agents import create_agent
import parlai.chat_service.utils.logging as log_utils
import parlai.chat_service.utils.server as server_utils
from parlai.utils.io import PathManager
from parlai.chat_service.services.messenger.agents import MessengerAgent
from parlai.chat_service.core.socket import ChatServiceMessageSocket
from parlai.chat_service.services.messenger.message_sender import MessageSender
Expand Down Expand Up @@ -222,7 +223,7 @@ def get_app_token(self):
"""
if not self.opt.get('force_page_token'):
if not os.path.exists(os.path.expanduser('~/.parlai/')):
os.makedirs(os.path.expanduser('~/.parlai/'))
PathManager.mkdirs(os.path.expanduser('~/.parlai/'))
access_token_file_path = '~/.parlai/messenger_token'
expanded_file_path = os.path.expanduser(access_token_file_path)
if os.path.exists(expanded_file_path):
Expand Down
15 changes: 8 additions & 7 deletions parlai/core/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@
``MultiTaskTeacher``.
"""

import copy

from parlai.core.build_data import modelzoo_path
from parlai.core.loader import load_agent_module
from parlai.core.loader import register_agent # noqa: F401
from parlai.core.opt import Opt
from parlai.utils.misc import warn_once
import copy
import os
import parlai.utils.logging as logging
from parlai.utils.io import PathManager


NOCOPY_ARGS = [
Expand Down Expand Up @@ -205,7 +206,7 @@ def compare_init_model_opts(opt: Opt, curr_opt: Opt):
return
opt['init_model'] = modelzoo_path(opt['datapath'], opt['init_model'])
optfile = opt['init_model'] + '.opt'
if not os.path.isfile(optfile):
if not PathManager.exists(optfile):
return
init_model_opt = Opt.load(optfile)

Expand Down Expand Up @@ -294,7 +295,7 @@ def create_agent_from_opt_file(opt: Opt):
model_file = opt['model_file']
optfile = model_file + '.opt'

if not os.path.isfile(optfile):
if not PathManager.exists(optfile):
return None

opt_from_file = Opt.load(optfile)
Expand Down Expand Up @@ -328,12 +329,12 @@ def create_agent_from_opt_file(opt: Opt):
# update dict file path
if not opt_from_file.get('dict_file'):
opt_from_file['dict_file'] = model_file + '.dict'
elif opt_from_file.get('dict_file') and not os.path.isfile(
elif opt_from_file.get('dict_file') and not PathManager.exists(
opt_from_file['dict_file']
):
old_dict_file = opt_from_file['dict_file']
opt_from_file['dict_file'] = model_file + '.dict'
if not os.path.isfile(opt_from_file['dict_file']):
if not PathManager.exists(opt_from_file['dict_file']):
warn_once(
'WARNING: Neither the specified dict file ({}) nor the '
'`model_file`.dict file ({}) exists, check to make sure either '
Expand Down Expand Up @@ -384,7 +385,7 @@ def create_agent(opt: Opt, requireModelExists=False):

if opt.get('model_file'):
opt['model_file'] = modelzoo_path(opt.get('datapath'), opt['model_file'])
if requireModelExists and not os.path.isfile(opt['model_file']):
if requireModelExists and not PathManager.exists(opt['model_file']):
raise RuntimeError(
'WARNING: Model file does not exist, check to make '
'sure it is correct: {}'.format(opt['model_file'])
Expand Down
Loading

0 comments on commit 8200396

Please sign in to comment.