Skip to content

Commit

Permalink
Fix some minor bugs and add instructions
Browse files: browse the repository at this point in the history
  • Loading branch information
Kaiyu Yang committed Apr 2, 2020
1 parent 9006ea8 commit be44fce
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 11 deletions.
Binary file removed EVALB/evalb
Binary file not shown.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ This is a PyTorch implementation of the parser described in ["Rethinking Self-At

* Python 3.6 or higher.
* The Python package requirements can be installed through the `requirements.sh` file.
* Run `make` in ./EVALB.

## Pre-trained models

Expand Down
4 changes: 2 additions & 2 deletions best_parser_training_script.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env bash
- CUDA_VISIBLE_DEVICES=6 python src_joint/main.py train \
+ python src_joint/main.py train \
--model-path-base models/joint_xlnet_clean_large_3_layers_no_resdrop_lambda \
--epochs 100 \
--use-xlnet \
Expand All @@ -20,4 +20,4 @@ CUDA_VISIBLE_DEVICES=6 python src_joint/main.py train \
--dep-dev-ptb-path data/ptb_dev_3.3.0.sd.clean \
--lal-d-kv 128 \
--lal-d-proj 128 \
--no-lal-resdrop
--no-lal-resdrop
16 changes: 8 additions & 8 deletions src_joint/KM_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
if use_cuda:
torch_t = torch.cuda
def from_numpy(ndarray):
- return torch.from_numpy(ndarray).pin_memory().cuda(async=True)
+ return torch.from_numpy(ndarray).pin_memory().cuda(non_blocking=True)
else:
print("Not using CUDA!")
torch_t = torch
Expand Down Expand Up @@ -331,7 +331,7 @@ def pad_and_rearrange(self, q_s, k_s, v_s, batch_idxs):
q_padded = q_s.new_zeros((n_head, mb_size, len_padded, d_k))
k_padded = k_s.new_zeros((n_head, mb_size, len_padded, d_k))
v_padded = v_s.new_zeros((n_head, mb_size, len_padded, d_v))
- invalid_mask = q_s.new_ones((mb_size, len_padded), dtype=torch.uint8)
+ invalid_mask = q_s.new_ones((mb_size, len_padded), dtype=torch.bool)

for i, (start, end) in enumerate(zip(batch_idxs.boundaries_np[:-1], batch_idxs.boundaries_np[1:])):
q_padded[:,i,:end-start,:] = q_s[:,start:end,:]
Expand Down Expand Up @@ -983,7 +983,7 @@ def pad_and_rearrange(self, q_s, k_s, v_s, batch_idxs):
q_padded = q_s.repeat(mb_size, 1, 1) # (d_l * mb_size) x 1 x d_k
k_padded = k_s.new_zeros((n_head, mb_size, len_padded, d_k))
v_padded = v_s.new_zeros((n_head, mb_size, len_padded, d_v))
- invalid_mask = q_s.new_ones((mb_size, len_padded), dtype=torch.uint8)
+ invalid_mask = q_s.new_ones((mb_size, len_padded), dtype=torch.bool)

for i, (start, end) in enumerate(zip(batch_idxs.boundaries_np[:-1], batch_idxs.boundaries_np[1:])):
if self.q_as_matrix:
Expand Down Expand Up @@ -1638,7 +1638,7 @@ def parse_batch(self, sentences, golds=None):
features = all_encoder_layers[-1]

if self.encoder is not None:
- features_packed = features.masked_select(all_word_end_mask.to(torch.uint8).unsqueeze(-1)).reshape(-1,
+ features_packed = features.masked_select(all_word_end_mask.to(torch.bool).unsqueeze(-1)).reshape(-1,
features.shape[
-1])

Expand Down Expand Up @@ -1736,7 +1736,7 @@ def parse_batch(self, sentences, golds=None):
# features = all_encoder_layers[-1]
features = transformer_outputs[0]

- features_packed = features.masked_select(all_word_end_mask.to(torch.uint8).unsqueeze(-1)).reshape(-1,
+ features_packed = features.masked_select(all_word_end_mask.to(torch.bool).unsqueeze(-1)).reshape(-1,
features.shape[
-1])

Expand Down Expand Up @@ -1779,8 +1779,8 @@ def parse_batch(self, sentences, golds=None):
else:
assert self.bert is not None
features = self.project_bert(features)
- fencepost_annotations_start = features.masked_select(all_word_start_mask.to(torch.uint8).unsqueeze(-1)).reshape(-1, features.shape[-1])
- fencepost_annotations_end = features.masked_select(all_word_end_mask.to(torch.uint8).unsqueeze(-1)).reshape(-1, features.shape[-1])
+ fencepost_annotations_start = features.masked_select(all_word_start_mask.to(torch.bool).unsqueeze(-1)).reshape(-1, features.shape[-1])
+ fencepost_annotations_end = features.masked_select(all_word_end_mask.to(torch.bool).unsqueeze(-1)).reshape(-1, features.shape[-1])

fp_startpoints = batch_idxs.boundaries_np[:-1]
fp_endpoints = batch_idxs.boundaries_np[1:] - 1
Expand Down Expand Up @@ -1976,4 +1976,4 @@ def make_tree():
tree_list = make_tree()
assert len(tree_list) == 1
tree = tree_list[0]
return tree, score
return tree, score
3 changes: 2 additions & 1 deletion test.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#!/usr/bin/env bash
python src_joint/main.py test \
--dataset ptb \
--eval-batch-size 8 \
--consttest-ptb-path data/23.auto.clean \
- --deptest-ptb-path data/ptb_test_3.3.0.sd \
+ --deptest-ptb-path data/ptb_test_3.3.0.sd.clean \
--embedding-path data/glove.gz \
--model-path-base best_parser.pt

0 comments on commit be44fce

Please sign in to comment.