make data building more platform-friendly (facebookresearch#47)
use os.path.join and stop using os.system in build_data.py
alexholdenmiller authored May 6, 2017
1 parent b92347f commit 9fc989a
Showing 49 changed files with 499 additions and 443 deletions.
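The gist of the change: path strings are now assembled with os.path.join instead of '/'-concatenation, and shell commands run through os.system are replaced with standard-library calls. A minimal sketch of the pattern (the directory name here is made up for illustration):

    import os
    import shutil

    data_dir = os.path.join('ParlAI', 'data', 'MyTask')  # hypothetical path

    # Before: shell out and hope the commands exist on the platform.
    # os.system('mkdir -p "%s"' % data_dir)
    # os.system('rm -rf "%s"' % data_dir)

    # After: portable calls that behave the same on Linux, macOS and Windows.
    os.makedirs(data_dir, exist_ok=True)
    shutil.rmtree(data_dir, ignore_errors=True)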
parlai/core/build_data.py (66 changes: 39 additions & 27 deletions)

@@ -8,59 +8,71 @@
 These can be replaced if your particular file system does not support them.
 """
 
+import datetime
 import os
+import requests
+import shutil
+import wget
 
 def built(path):
-    return os.path.isfile(path + "/.built")
+    """Checks if '.built' flag has been set for that task."""
+    return os.path.isfile(os.path.join(path, '.built'))
 
 def download(path, url):
-    s = ('cd "%s"' % path) + '; wget ' + url
-    if os.system(s) != 0:
-        raise RuntimeError('failed: ' + s)
+    """Downloads file using `wget`."""
+    filename = wget.download(url, out=path)
+    print()  # wget prints download status, without newline
+
+def download_request(url, path, fname):
+    """Downloads file using `requests`."""
+    with requests.Session() as session:
+        response = session.get(url, stream=True)
+        CHUNK_SIZE = 32768
+        with open(os.path.join(path, fname), 'wb') as f:
+            for chunk in response.iter_content(CHUNK_SIZE):
+                if chunk:  # filter out keep-alive new chunks
+                    f.write(chunk)
+        response.close()
 
 def make_dir(path):
-    s = ('mkdir -p "%s"' % (path))
-    if os.system(s) != 0:
-        raise RuntimeError('failed: ' + s)
+    """Makes the directory and any nonexistent parent directories."""
+    os.makedirs(path, exist_ok=True)
 
 def mark_done(path):
-    s = ('date > "%s"/.built' % path)
-    if os.system(s) != 0:
-        raise RuntimeError('failed: ' + s)
+    """Marks the path as done by adding a '.built' file with the current
+    timestamp.
+    """
+    with open(os.path.join(path, '.built'), 'w') as write:
+        write.write(str(datetime.datetime.today()))
 
 def move(path1, path2):
-    s = ('mv "%s" "%s"' % (path1, path2))
-    if os.system(s) != 0:
-        raise RuntimeError('failed: ' + s)
+    """Renames the given file."""
+    shutil.move(path1, path2)
 
 def remove_dir(path):
-    s = ('rm -rf "%s"' % (path))
-    if os.system(s) != 0:
-        raise RuntimeError('failed: ' + s)
+    """Removes the given directory, if it exists."""
+    shutil.rmtree(path, ignore_errors=True)
 
 def untar(path, fname, deleteTar=True):
+    """Unpacks the given archive file to the same directory, then (by default)
+    deletes the archive file.
+    """
     print('unpacking ' + fname)
-    s = ('cd "%s"' % path) + ';' + 'tar xfz "%s"' % (path + fname)
-    if os.system(s) != 0:
-        raise RuntimeError('failed: ' + s)
-    # remove tar file
+    fullpath = os.path.join(path, fname)
+    shutil.unpack_archive(fullpath, path)
     if deleteTar:
-        s = ('cd "%s"' % path) + ';' + 'rm "%s"' % (path + fname)
-        if os.system(s) != 0:
-            raise RuntimeError('failed: ' + s)
+        os.remove(fullpath)
 
 def _get_confirm_token(response):
     for key, value in response.cookies.items():
         if key.startswith('download_warning'):
             return value
     return None
 
-def download_file_from_google_drive(gd_id, destination):
-    import requests
-
+def download_from_google_drive(gd_id, destination):
+    """Uses the requests package to download a file from Google Drive."""
     URL = 'https://docs.google.com/uc?export=download'
 
-    session = requests.Session()
+    with requests.Session() as session:
+        response = session.get(URL, params={'id': gd_id}, stream=True)
+        token = _get_confirm_token(response)
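The rewritten helpers keep the usual build flow for a task's build script. A short sketch of how they compose, assuming a hypothetical task directory and archive URL (not part of this commit):

    import os
    import parlai.core.build_data as build_data

    def build(opt):
        dpath = os.path.join(opt['datapath'], 'MyTask')  # hypothetical task dir
        if not build_data.built(dpath):
            build_data.remove_dir(dpath)  # clear any partial build
            build_data.make_dir(dpath)
            build_data.download(dpath, 'http://example.com/mytask.tar.gz')
            build_data.untar(dpath, 'mytask.tar.gz')
            build_data.mark_done(dpath)   # writes the .built timestamp flag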
parlai/core/params.py (16 changes: 5 additions & 11 deletions)

@@ -14,12 +14,6 @@
 def str2bool(value):
     return value.lower() in ('yes', 'true', 't', '1', 'y')
 
-def path(s):
-    # Add a trailing slash if its not there.
-    if s[-1] != '/':
-        s += '/'
-    return s
-
 class ParlaiParser(object):
     """Pseudo-extension of argparse which sets a number of parameters for the
     ParlAI framework. More options can be added specific to other modules by
@@ -43,15 +37,15 @@ def __init__(self, add_parlai_args=True, add_model_args=False):
     def add_parlai_data_path(self):
         parlai_dir = (os.path.dirname(os.path.dirname(os.path.dirname(
             os.path.realpath(__file__)))))
-        default_data_path = parlai_dir + '/data/'
+        default_data_path = os.path.join(parlai_dir, 'data')
         self.parser.add_argument(
             '-dp', '--datapath', default=default_data_path,
             help='path to datasets, defaults to {parlai_dir}/data')
 
     def add_mturk_log_path(self):
         parlai_dir = (os.path.dirname(os.path.dirname(os.path.dirname(
             os.path.realpath(__file__)))))
-        default_log_path = parlai_dir + '/logs/mturk/'
+        default_log_path = os.path.join(parlai_dir, 'logs', 'mturk')
         self.parser.add_argument(
             '--mturk-log-path', default=default_log_path,
             help='path to mturk logs, defaults to {parlai_dir}/logs/mturk')
@@ -60,7 +54,7 @@ def add_parlai_args(self):
         parlai_dir = (os.path.dirname(os.path.dirname(os.path.dirname(
             os.path.realpath(__file__)))))
         os.environ['PARLAI_HOME'] = parlai_dir
-        default_downloads_path = parlai_dir + '/downloads/'
+        default_downloads_path = os.path.join(parlai_dir, 'downloads')
 
         self.parser.add_argument(
             '-t', '--task',
@@ -103,10 +97,10 @@ def parse_args(self, args=None, print_args=True):
         self.args = self.parser.parse_args(args=args)
         self.opt = {k: v for k, v in vars(self.args).items() if v is not None}
         if 'download_path' in self.opt:
-            self.opt['download_path'] = path(self.opt['download_path'])
+            self.opt['download_path'] = self.opt['download_path']
            os.environ['PARLAI_DOWNPATH'] = self.opt['download_path']
         if 'datapath' in self.opt:
-            self.opt['datapath'] = path(self.opt['datapath'])
+            self.opt['datapath'] = self.opt['datapath']
         if print_args:
             self.print_args()
         return self.opt
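The removed path() helper existed only to guarantee a trailing slash before string concatenation; os.path.join makes that normalization unnecessary because it inserts the separator itself. A quick illustration (the directories are hypothetical):

    import os

    parlai_dir = '/home/user/ParlAI'               # hypothetical checkout
    datapath = os.path.join(parlai_dir, 'data')    # no trailing slash needed
    print(os.path.join(datapath, 'mnist'))         # /home/user/ParlAI/data/mnist on POSIX
    # The old concatenation needed the slash: datapath + 'mnist' would give
    # '/home/user/ParlAI/datamnist' without it.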
parlai/mturk/core/manage_hit.py (56 changes: 29 additions & 27 deletions)

@@ -40,7 +40,7 @@ def _send_new_message(json_api_endpoint_url, task_group_id, conversation_id, age
     post_data_dict['text'] = message_text
     if reward:
         post_data_dict['reward'] = reward
-    
+
     request = requests.post(json_api_endpoint_url, data=json.dumps(post_data_dict))
     return json.loads(request.json())
@@ -90,8 +90,8 @@ def create_hits(opt, task_config, task_module_name, bot, num_hits, hit_reward, i
 
     approval_index_url_template = html_api_endpoint_url + "?method_name=approval_index&task_group_id={{task_group_id}}&conversation_id=1&cur_agent_id={{cur_agent_id}}&requester_key="+requester_key_gt
 
-    worker_agent_id = task_config['worker_agent_id']
-    bot_agent_id = bot.getID()
+    worker_agent_id = task_config['worker_agent_id']
+    bot_agent_id = bot.getID()
     cids = range(1, num_hits+1)
     cid_map = {cid: i for i, cid in enumerate(cids)}
     c_done_map = {cid: False for cid in cids}
@@ -115,12 +115,12 @@ def create_hits(opt, task_config, task_module_name, bot, num_hits, hit_reward, i
         logs[cid].append(response)
         new_message = _send_new_message(
             json_api_endpoint_url=json_api_endpoint_url,
-            task_group_id=task_group_id,
-            conversation_id=cid,
-            agent_id=bot_agent_id,
-            message_text=response.get('text', None),
+            task_group_id=task_group_id,
+            conversation_id=cid,
+            agent_id=bot_agent_id,
+            message_text=response.get('text', None),
             reward=response.get('reward', None),
-            episode_done=response.get('episode_done', False),
+            episode_done=response.get('episode_done', False),
         )
         if new_message['message_id'] > last_message_id:
             last_message_id = new_message['message_id']
@@ -132,13 +132,13 @@ def create_hits(opt, task_config, task_module_name, bot, num_hits, hit_reward, i
     while len(conversations_remaining) > 0:
         ret = _get_new_messages(
             json_api_endpoint_url=json_api_endpoint_url,
-            task_group_id=task_group_id,
-            after_message_id=last_message_id,
+            task_group_id=task_group_id,
+            after_message_id=last_message_id,
             excluded_agent_id=bot_agent_id,
         )
         conversation_dict = ret['conversation_dict']
         new_last_message_id = ret['last_message_id']
-        
+
         if new_last_message_id:
             last_message_id = new_last_message_id
@@ -168,22 +168,22 @@ def create_hits(opt, task_config, task_module_name, bot, num_hits, hit_reward, i
            logs[conversation_id].append(response)
            _send_new_message(
                json_api_endpoint_url=json_api_endpoint_url,
-               task_group_id=task_group_id,
-               conversation_id=conversation_id,
-               agent_id=bot_agent_id,
-               message_text=response.get('text', None),
+               task_group_id=task_group_id,
+               conversation_id=conversation_id,
+               agent_id=bot_agent_id,
+               message_text=response.get('text', None),
                reward=response.get('reward', None),
-               episode_done=response.get('episode_done', False),
+               episode_done=response.get('episode_done', False),
            )
 
        # We don't create new HITs until this point, so that the HIT page will always have the conversation fully populated.
        if not hits_created:
            print('Creating HITs...')
            hit_type_id = create_hit_type(
-               hit_title=task_config['hit_title'],
-               hit_description=task_config['hit_description'] + ' (ID: ' + task_group_id + ')',
-               hit_keywords=task_config['hit_keywords'],
-               hit_reward=hit_reward,
+               hit_title=task_config['hit_title'],
+               hit_description=task_config['hit_description'] + ' (ID: ' + task_group_id + ')',
+               hit_keywords=task_config['hit_keywords'],
+               hit_reward=hit_reward,
                is_sandbox=is_sandbox
            )
            mturk_chat_url = None
@@ -192,8 +192,8 @@ def create_hits(opt, task_config, task_module_name, bot, num_hits, hit_reward, i
                mturk_chat_url = html_api_endpoint_url + "?method_name=chat_index&task_group_id="+str(task_group_id)+"&conversation_id="+str(cid)+"&cur_agent_id="+str(worker_agent_id)
                if not chat_page_only:
                    mturk_page_url = create_hit_with_hit_type(
-                       page_url=mturk_chat_url,
-                       hit_type_id=hit_type_id,
+                       page_url=mturk_chat_url,
+                       hit_type_id=hit_type_id,
                        is_sandbox=is_sandbox
                    )
 
@@ -230,11 +230,11 @@ def create_hits(opt, task_config, task_module_name, bot, num_hits, hit_reward, i
    # Saving logs to file
    # Log format: {conversation_id: [list of messages in the conversation]}
    mturk_log_path = opt['mturk_log_path']
-   task_group_path = mturk_log_path + task_module_name + '_' + datetime.now().strftime("%Y-%m-%d_%H:%M:%S") + '/'
+   task_group_path = os.path.join(mturk_log_path,
+                                  task_module_name + '_' +
+                                  datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))
    os.makedirs(task_group_path)
-   with open(task_group_path+'approved.json', 'w') as file:
-       file.write(json.dumps(logs_approved))
-   with open(task_group_path+'rejected.json', 'w') as file:
-       file.write(json.dumps(logs_rejected))
+   with open(os.path.join(task_group_path, 'approved.json'), 'w') as fout:
+       fout.write(json.dumps(logs_approved))
+   with open(os.path.join(task_group_path, 'rejected.json'), 'w') as fout:
+       fout.write(json.dumps(logs_rejected))
 
    print("All conversations are saved to "+opt['mturk_log_path']+" in JSON format.\n")