Remove support for use_chars_concat
use_chars_lstm always performs better, and using ELMo is even better
nikitakit committed Dec 20, 2018
1 parent 2cc0220 commit b200ac6
Showing 2 changed files with 5 additions and 39 deletions.
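
For configurations that previously enabled use_chars_concat, character-level input now has to come from the character LSTM (or from ELMo). A minimal sketch of the replacement settings, assuming only the hyperparameter names visible in the make_hparams() diff below; how the defaults are overridden in practice (command line vs. editing src/main.py) is outside this commit:

    # Illustrative overrides only -- the keys mirror hyperparameters shown in this
    # diff, not necessarily the full set the parser expects.
    hparams_overrides = {
        "use_chars_lstm": True,   # replacement for the removed use_chars_concat
        "use_elmo": False,        # or True: per the commit message, ELMo does even better
        "d_char_emb": 64,         # default is 32; the diff notes a larger value may help use_chars_lstm
    }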
1 change: 0 additions & 1 deletion src/main.py
@@ -63,7 +63,6 @@ def make_hparams():
use_tags=False,
use_words=False,
use_chars_lstm=False,
- use_chars_concat=False,
use_elmo=False,

d_char_emb=32, # A larger value may be better for use_chars_lstm
43 changes: 5 additions & 38 deletions src/parse_nk.py
@@ -635,33 +635,21 @@ def __init__(
self.use_tags = hparams.use_tags

self.morpho_emb_dropout = None
- if hparams.use_chars_lstm or hparams.use_chars_concat or hparams.use_elmo:
+ if hparams.use_chars_lstm or hparams.use_elmo:
self.morpho_emb_dropout = hparams.morpho_emb_dropout
else:
- assert self.emb_types, "Need at least one of: use_tags, use_words, use_chars_lstm, use_chars_concat, use_elmo"
+ assert self.emb_types, "Need at least one of: use_tags, use_words, use_chars_lstm, use_elmo"

self.char_encoder = None
- self.char_embedding = None
self.elmo = None
if hparams.use_chars_lstm:
- assert not hparams.use_chars_concat, "use_chars_lstm and use_chars_concat are mutually exclusive"
assert not hparams.use_elmo, "use_chars_lstm and use_elmo are mutually exclusive"
self.char_encoder = CharacterLSTM(
num_embeddings_map['chars'],
hparams.d_char_emb,
self.d_content,
char_dropout=hparams.char_lstm_input_dropout,
)
- elif hparams.use_chars_concat:
-     assert not hparams.use_elmo, "use_chars_concat and use_elmo are mutually exclusive"
-     self.num_chars_flat = self.d_content // hparams.d_char_emb
-     assert self.num_chars_flat >= 2, "incompatible settings of d_model/partitioned and d_char_emb"
-     assert self.num_chars_flat == (self.d_content / hparams.d_char_emb), "d_char_emb does not evenly divide model size"
-
-     self.char_embedding = nn.Embedding(
-         num_embeddings_map['chars'],
-         hparams.d_char_emb,
-     )
elif hparams.use_elmo:
self.elmo = get_elmo_class()(
options_file="data/elmo_2x4096_512_2048cnn_2xhighway_options.json",
@@ -724,6 +712,8 @@ def model(self):
def from_spec(cls, spec, model):
spec = spec.copy()
hparams = spec['hparams']
+ if 'use_chars_concat' in hparams and hparams['use_chars_concat']:
+     raise NotImplementedError("Support for use_chars_concat has been removed")
if 'sentence_max_len' not in hparams:
hparams['sentence_max_len'] = 300
if 'use_elmo' not in hparams:
@@ -833,30 +823,7 @@ def parse_batch(self, sentences, golds=None, return_label_scores_charts=False):
assert i == packed_len

extra_content_annotations = self.char_encoder(char_idxs_encoder, word_lens_encoder, batch_idxs)
- elif self.char_embedding is not None:
-     char_idxs_encoder = np.zeros((packed_len, self.num_chars_flat), dtype=int)
-
-     i = 0
-     for snum, sentence in enumerate(sentences):
-         for wordnum, (tag, word) in enumerate([(START, START)] + sentence + [(STOP, STOP)]):
-             if word == START:
-                 char_idxs_encoder[i, :] = self.char_vocab.index(CHAR_START_SENTENCE)
-             elif word == STOP:
-                 char_idxs_encoder[i, :] = self.char_vocab.index(CHAR_STOP_SENTENCE)
-             else:
-                 word_chars = (([self.char_vocab.index(CHAR_START_WORD)] * self.num_chars_flat)
-                     + [self.char_vocab.index_or_unk(char, CHAR_UNK) for char in word]
-                     + ([self.char_vocab.index(CHAR_STOP_WORD)] * self.num_chars_flat))
-                 char_idxs_encoder[i, :self.num_chars_flat//2] = word_chars[self.num_chars_flat:self.num_chars_flat + self.num_chars_flat//2]
-                 char_idxs_encoder[i, self.num_chars_flat//2:] = word_chars[::-1][self.num_chars_flat:self.num_chars_flat + self.num_chars_flat//2]
-             i += 1
-     assert i == packed_len
-
-     char_idxs_encoder = from_numpy(char_idxs_encoder)
-
-     extra_content_annotations = self.char_embedding(char_idxs_encoder)
-     extra_content_annotations = extra_content_annotations.view(-1, self.num_chars_flat * self.char_embedding.embedding_dim)
- if self.elmo is not None:
+ elif self.elmo is not None:
# See https://github.com/allenai/allennlp/blob/c3c3549887a6b1fb0bc8abf77bc820a3ab97f788/allennlp/data/token_indexers/elmo_indexer.py#L61
# ELMO_START_SENTENCE = 256
# ELMO_STOP_SENTENCE = 257
