Merge pull request facebookresearch#32 from facebookresearch/cuda_fix
remove pin_memory for non-cuda compatibility
ajfisch authored May 3, 2017
2 parents 53b400e + ae77cbd commit 8b1d1da
Showing 1 changed file with 5 additions and 5 deletions.
parlai/agents/drqa/utils.py: 5 additions, 5 deletions
@@ -107,18 +107,18 @@ def batchify(batch, null=0):
 
     # Batch documents and features
     max_length = max([d.size(0) for d in docs])
-    x1 = torch.LongTensor(len(docs), max_length).fill_(null).pin_memory()
-    x1_mask = torch.ByteTensor(len(docs), max_length).fill_(1).pin_memory()
-    x1_f = torch.zeros(len(docs), max_length, features[0].size(1)).pin_memory()
+    x1 = torch.LongTensor(len(docs), max_length).fill_(null)
+    x1_mask = torch.ByteTensor(len(docs), max_length).fill_(1)
+    x1_f = torch.zeros(len(docs), max_length, features[0].size(1))
     for i, d in enumerate(docs):
         x1[i, :d.size(0)].copy_(d)
         x1_mask[i, :d.size(0)].fill_(0)
         x1_f[i, :d.size(0)].copy_(features[i])
 
     # Batch questions
     max_length = max([q.size(0) for q in questions])
-    x2 = torch.LongTensor(len(questions), max_length).fill_(null).pin_memory()
-    x2_mask = torch.ByteTensor(len(questions), max_length).fill_(1).pin_memory()
+    x2 = torch.LongTensor(len(questions), max_length).fill_(null)
+    x2_mask = torch.ByteTensor(len(questions), max_length).fill_(1)
     for i, q in enumerate(questions):
         x2[i, :q.size(0)].copy_(q)
         x2_mask[i, :q.size(0)].fill_(0)
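
The fix drops host-memory pinning entirely rather than making it conditional: .pin_memory() is only useful when tensors are later copied to a GPU, and calling it on a non-CUDA setup is what broke the old code per the commit message. A minimal sketch of the conditional alternative, assuming only PyTorch (maybe_pin is a hypothetical helper, not part of this PR or of ParlAI):

import torch

def maybe_pin(tensor):
    # Hypothetical helper: pin page-locked host memory only when CUDA is
    # available. Pinned memory speeds up host-to-GPU copies and lets them
    # run asynchronously; without a GPU there is nothing to copy to, so
    # we return the tensor unchanged.
    if torch.cuda.is_available():
        return tensor.pin_memory()
    return tensor

x = maybe_pin(torch.zeros(32, 100))  # pinned only if a CUDA device exists
# A pinned tensor can then be moved to the GPU without blocking the host:
# x_gpu = x.to('cuda', non_blocking=True)

Batch tensors built this way keep the fast transfer path on GPU machines while remaining runnable on CPU-only installs; the PR opted for the simpler unconditional removal.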