Skip to content

Commit

Permalink
split file name fix
Browse files Browse the repository at this point in the history
  • Loading branch information
saransh-mehta committed Jun 10, 2020
1 parent 4c8d04f commit 3c07346
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
3 changes: 2 additions & 1 deletion data_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ def main():
print('Loading raw data for task {} from {}'.format(taskName, os.path.join(args.data_dir, file)))
rows = load_data(os.path.join(args.data_dir, file), tasks.taskTypeMap[taskName],
hasLabels = args.has_labels)
- wrtFile = os.path.join(dataPath, '{}.json'.format(file.split('.')[0]))
+ #wrtFile = os.path.join(dataPath, '{}.json'.format(file.split('.')[0]))
+ wrtFile = os.path.join(dataPath, '{}.json'.format(file.lower().replace('.tsv', '')))
print('Processing Started...')
create_data_multithreaded(rows, wrtFile, tokenizer, tasks, taskName,
args.max_seq_len, args.multithreaded)
Expand Down
6 changes: 3 additions & 3 deletions examples/intent_ner_fragment/tasks_file_snips.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ fragdetect:
loss_type: CrossEntropyLoss
task_type: SingleSenClassification
file_names:
- - fragment_snips_train.tsv
- - fragment_snips_dev.tsv
- - fragment_snips_test.tsv
+ - fragment_intent_snips_train.tsv
+ - fragment_intent_snips_dev.tsv
+ - fragment_intent_snips_test.tsv
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ def make_data_handlers(taskParams, mode, isTrain, gpu):
taskType = taskParams.taskTypeMap[taskName]
if mode == "test":
assert len(taskParams.fileNamesMap[taskName])==3, "test file is required along with train, dev"
- dataFileName = '{}.json'.format(taskParams.fileNamesMap[taskName][modeIdx].split('.')[0])
+ #dataFileName = '{}.json'.format(taskParams.fileNamesMap[taskName][modeIdx].split('.')[0])
+ dataFileName = '{}.json'.format(taskParams.fileNamesMap[taskName][modeIdx].lower().replace('.tsv',''))
taskDataPath = os.path.join(args.data_dir, dataFileName)
assert os.path.exists(taskDataPath), "{} doesn't exist".format(taskDataPath)
taskDict = {"data_task_id" : int(taskId),
Expand Down

0 comments on commit 3c07346

Please sign in to comment.