Skip to content

Commit

Permalink
Small fixes (facebookresearch#55)
Browse files Browse the repository at this point in the history
fixes from issues, especially adding __init__.py files and unit test for them
  • Loading branch information
alexholdenmiller authored May 8, 2017
1 parent c4eab57 commit da9f230
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 13 deletions.
4 changes: 4 additions & 0 deletions examples/memnn_luatorch_cpu/full_task_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def main():
argparser.add_argument('--num_its', default=100, type=int)
parlai_home = os.environ['PARLAI_HOME']
if '--remote-cmd' not in sys.argv:
if os.system('which luajit') != 0:
raise RuntimeError('Could not detect torch luajit installed: ' +
'please install torch from http://torch.ch ' +
'or manually set --remote-cmd for this example.')
sys.argv.append('--remote-cmd')
sys.argv.append('luajit {}/parlai/agents/'.format(parlai_home) +
'memnn_luatorch_cpu/memnn_zmq_parsed.lua')
Expand Down
5 changes: 5 additions & 0 deletions parlai/agents/ir_baseline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
21 changes: 10 additions & 11 deletions parlai/core/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
def str2bool(value):
return value.lower() in ('yes', 'true', 't', '1', 'y')


class ParlaiParser(object):
"""Pseudo-extension of argparse which sets a number of parameters for the
ParlAI framework. More options can be added specific to other modules by
Expand All @@ -25,6 +26,11 @@ class ParlaiParser(object):
def __init__(self, add_parlai_args=True, add_model_args=False):
self.parser = argparse.ArgumentParser(description='ParlAI parser.')
self.parser.register('type', 'bool', str2bool)

self.parlai_home = (os.path.dirname(os.path.dirname(os.path.dirname(
os.path.realpath(__file__)))))
os.environ['PARLAI_HOME'] = self.parlai_home

if add_parlai_args:
self.add_parlai_args()
if add_model_args:
Expand All @@ -35,27 +41,19 @@ def __init__(self, add_parlai_args=True, add_model_args=False):
self.register = self.parser.register

def add_parlai_data_path(self):
parlai_dir = (os.path.dirname(os.path.dirname(os.path.dirname(
os.path.realpath(__file__)))))
default_data_path = os.path.join(parlai_dir, 'data')
default_data_path = os.path.join(self.parlai_home , 'data')
self.parser.add_argument(
'-dp', '--datapath', default=default_data_path,
help='path to datasets, defaults to {parlai_dir}/data')

def add_mturk_log_path(self):
parlai_dir = (os.path.dirname(os.path.dirname(os.path.dirname(
os.path.realpath(__file__)))))
default_log_path = os.path.join(parlai_dir, 'logs', 'mturk')
default_log_path = os.path.join(self.parlai_home , 'logs', 'mturk')
self.parser.add_argument(
'--mturk-log-path', default=default_log_path,
help='path to mturk logs, defaults to {parlai_dir}/logs/mturk')

def add_parlai_args(self):
parlai_dir = (os.path.dirname(os.path.dirname(os.path.dirname(
os.path.realpath(__file__)))))
os.environ['PARLAI_HOME'] = parlai_dir
default_downloads_path = os.path.join(parlai_dir, 'downloads')

default_downloads_path = os.path.join(self.parlai_home, 'downloads')
self.parser.add_argument(
'-t', '--task',
help='ParlAI task(s), e.g. "babi:Task1" or "babi,cbt"')
Expand Down Expand Up @@ -96,6 +94,7 @@ def parse_args(self, args=None, print_args=True):
"""
self.args = self.parser.parse_args(args=args)
self.opt = {k: v for k, v in vars(self.args).items() if v is not None}
self.opt['parlai_home'] = self.parlai_home
if 'download_path' in self.opt:
self.opt['download_path'] = self.opt['download_path']
os.environ['PARLAI_DOWNPATH'] = self.opt['download_path']
Expand Down
5 changes: 5 additions & 0 deletions parlai/mturk/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
5 changes: 5 additions & 0 deletions parlai/mturk/tasks/model_evaluator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
5 changes: 5 additions & 0 deletions parlai/mturk/tasks/qa_data_collection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
wget
boto3
requests
nltk
numpy
pyzmq
wget
1 change: 0 additions & 1 deletion requirements_ext.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
pytorch
pyzmq
regex
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.


from setuptools import setup, find_packages
import sys

if sys.version_info < (3,):
sys.exit('Sorry, Python3 is required for ParlAI.')

with open('README.md') as f:
readme = f.read()
Expand Down
23 changes: 23 additions & 0 deletions tests/test_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

import os
import unittest


class TestInit(unittest.TestCase):
"""Make sure the package is alive."""

def test_init_everywhere(self):
from parlai.core.params import ParlaiParser
opt = ParlaiParser().parse_args()
for root, subfolder, files in os.walk(os.path.join(opt['parlai_home'], 'parlai')):
if not root.endswith('__pycache__'):
assert '__init__.py' in files, 'Dir {} is missing __init__.py'.format(root)


if __name__ == '__main__':
unittest.main()

0 comments on commit da9f230

Please sign in to comment.