
Commit 2bb1a73

add DocumentCNNEmbeddings
1 parent 69d7c7d commit 2bb1a73

File tree

4 files changed: +179 -1 lines changed


flair/embeddings/__init__.py

+1

@@ -26,6 +26,7 @@
 from .document import DocumentTFIDFEmbeddings
 from .document import DocumentRNNEmbeddings
 from .document import DocumentLMEmbeddings
+from .document import DocumentCNNEmbeddings
 from .document import SentenceTransformerDocumentEmbeddings

 # Expose image embedding classes

flair/embeddings/document.py

+159

@@ -688,3 +688,162 @@ def _add_embeddings_to_sentences(self, sentences: List[Sentence]):
     def embedding_length(self) -> int:
         """Returns the length of the embedding vector."""
         return self.model.get_sentence_embedding_dimension()
+
+
+class DocumentCNNEmbeddings(DocumentEmbeddings):
+    def __init__(
+        self,
+        embeddings: List[TokenEmbeddings],
+        kernels=((100, 3), (100, 4), (100, 5)),
+        reproject_words: bool = True,
+        reproject_words_dimension: int = None,
+        dropout: float = 0.5,
+        word_dropout: float = 0.0,
+        locked_dropout: float = 0.0,
+        fine_tune: bool = True,
+    ):
+        """The constructor takes a list of embeddings to be combined.
+        :param embeddings: a list of token embeddings
+        :param kernels: list of (number of kernels, kernel size) tuples
+        :param reproject_words: boolean value, indicating whether to reproject the token embeddings in a separate linear
+        layer before putting them into the CNN or not
+        :param reproject_words_dimension: output dimension of reprojecting token embeddings. If None, the same output
+        dimension as before will be taken.
+        :param dropout: the dropout value to be used
+        :param word_dropout: the word dropout value to be used, if 0.0 word dropout is not used
+        :param locked_dropout: the locked dropout value to be used, if 0.0 locked dropout is not used
+        """
+        super().__init__()
+
+        self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings)
+        self.length_of_all_token_embeddings: int = self.embeddings.embedding_length
+
+        self.kernels = kernels
+        self.reproject_words = reproject_words
+
+        self.static_embeddings = False if fine_tune else True
+
+        self.embeddings_dimension: int = self.length_of_all_token_embeddings
+        if self.reproject_words and reproject_words_dimension is not None:
+            self.embeddings_dimension = reproject_words_dimension
+
+        self.word_reprojection_map = torch.nn.Linear(
+            self.length_of_all_token_embeddings, self.embeddings_dimension
+        )
+
+        # CNN: one Conv1d per (kernel_num, kernel_size) pair; pooled outputs are concatenated
+        self.__embedding_length: int = sum([kernel_num for kernel_num, kernel_size in self.kernels])
+        self.convs = torch.nn.ModuleList(
+            [
+                torch.nn.Conv1d(self.embeddings_dimension, kernel_num, kernel_size) for kernel_num, kernel_size in self.kernels
+            ]
+        )
+        self.pool = torch.nn.AdaptiveMaxPool1d(1)
+
+        self.name = "document_cnn"
+
+        # dropouts
+        self.dropout = torch.nn.Dropout(dropout) if dropout > 0.0 else None
+        self.locked_dropout = (
+            LockedDropout(locked_dropout) if locked_dropout > 0.0 else None
+        )
+        self.word_dropout = WordDropout(word_dropout) if word_dropout > 0.0 else None
+
+        torch.nn.init.xavier_uniform_(self.word_reprojection_map.weight)
+
+        self.to(flair.device)
+
+        self.eval()
+
+    @property
+    def embedding_length(self) -> int:
+        return self.__embedding_length
+
+    def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
+        """Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update
+        only if embeddings are non-static."""
+
+        # TODO: remove in future versions
+        if not hasattr(self, "locked_dropout"):
+            self.locked_dropout = None
+        if not hasattr(self, "word_dropout"):
+            self.word_dropout = None
+
+        if type(sentences) is Sentence:
+            sentences = [sentences]
+
+        self.zero_grad()  # is it necessary?
+
+        # embed words in the sentence
+        self.embeddings.embed(sentences)
+
+        lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
+        longest_token_sequence_in_batch: int = max(lengths)
+
+        pre_allocated_zero_tensor = torch.zeros(
+            self.embeddings.embedding_length * longest_token_sequence_in_batch,
+            dtype=torch.float,
+            device=flair.device,
+        )
+
+        all_embs: List[torch.Tensor] = list()
+        for sentence in sentences:
+            all_embs += [
+                emb for token in sentence for emb in token.get_each_embedding()
+            ]
+            nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)
+
+            if nb_padding_tokens > 0:
+                t = pre_allocated_zero_tensor[
+                    : self.embeddings.embedding_length * nb_padding_tokens
+                ]
+                all_embs.append(t)
+
+        sentence_tensor = torch.cat(all_embs).view(
+            [
+                len(sentences),
+                longest_token_sequence_in_batch,
+                self.embeddings.embedding_length,
+            ]
+        )
+
+        # before-CNN dropout
+        if self.dropout:
+            sentence_tensor = self.dropout(sentence_tensor)
+        if self.locked_dropout:
+            sentence_tensor = self.locked_dropout(sentence_tensor)
+        if self.word_dropout:
+            sentence_tensor = self.word_dropout(sentence_tensor)
+
+        # reproject if set
+        if self.reproject_words:
+            sentence_tensor = self.word_reprojection_map(sentence_tensor)
+
+        # push through CNN
+        x = sentence_tensor
+        x = x.permute(0, 2, 1)
+
+        rep = [self.pool(torch.nn.functional.relu(conv(x))) for conv in self.convs]
+        outputs = torch.cat(rep, 1)
+
+        outputs = outputs.reshape(outputs.size(0), -1)
+
+        # after-CNN dropout
+        if self.dropout:
+            outputs = self.dropout(outputs)
+        if self.locked_dropout:
+            outputs = self.locked_dropout(outputs)
+
+        # extract embeddings from CNN
+        for sentence_no, length in enumerate(lengths):
+            embedding = outputs[sentence_no]
+
+            if self.static_embeddings:
+                embedding = embedding.detach()
+
+            sentence = sentences[sentence_no]
+            sentence.set_embedding(self.name, embedding)
+
+    def _apply(self, fn):
+        for child_module in self.children():
+            child_module._apply(fn)
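
For reference, a minimal usage sketch of the new class (hedged: the GloVe model name and the printed shape are illustrative, not part of this commit):

    from flair.data import Sentence
    from flair.embeddings import WordEmbeddings, DocumentCNNEmbeddings

    # token embeddings feeding the CNN; GloVe is just an illustrative choice
    glove = WordEmbeddings("glove")

    # three parallel Conv1d layers with 100 filters of widths 3, 4 and 5,
    # so the document embedding length is 100 + 100 + 100 = 300
    document_embeddings = DocumentCNNEmbeddings(
        [glove], kernels=((100, 3), (100, 4), (100, 5))
    )

    sentence = Sentence("Berlin is a great place to live.")
    document_embeddings.embed(sentence)

    print(sentence.get_embedding().shape)  # expected: torch.Size([300])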

flair/models/text_classification_model.py

+1 -1

@@ -53,7 +53,7 @@ def __init__(

         super(TextClassifier, self).__init__()

-        self.document_embeddings: flair.embeddings.DocumentRNNEmbeddings = document_embeddings
+        self.document_embeddings: flair.embeddings.DocumentEmbeddings = document_embeddings
         self.label_dictionary: Dictionary = label_dictionary
         self.label_type = label_type
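
Because the annotation is widened from DocumentRNNEmbeddings to the DocumentEmbeddings base class, TextClassifier now accepts any document embedding, including the new CNN variant. A rough sketch, assuming the document_embeddings instance from the example above and a hypothetical two-label sentiment dictionary:

    from flair.data import Dictionary
    from flair.models import TextClassifier

    # hypothetical label dictionary for a binary sentiment task
    label_dict = Dictionary(add_unk=False)
    label_dict.add_item("POSITIVE")
    label_dict.add_item("NEGATIVE")

    classifier = TextClassifier(
        document_embeddings=document_embeddings,  # e.g. a DocumentCNNEmbeddings instance
        label_dictionary=label_dict,
        label_type="sentiment",
    )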

tests/test_embeddings.py

+18

@@ -9,6 +9,7 @@
     FlairEmbeddings,
     DocumentRNNEmbeddings,
     DocumentLMEmbeddings, TransformerWordEmbeddings, TransformerDocumentEmbeddings,
+    DocumentCNNEmbeddings,
 )

 from flair.data import Sentence, Dictionary

@@ -287,4 +288,21 @@ def test_transformer_document_embeddings():

     sentence.clear_embeddings()

+    del embeddings
+
+
+def test_document_cnn_embeddings():
+    sentence: Sentence = Sentence("I love Berlin. Berlin is a great place to live.")
+
+    embeddings: DocumentCNNEmbeddings = DocumentCNNEmbeddings(
+        [glove, flair_embedding], kernels=((50, 2), (50, 3))
+    )
+
+    embeddings.embed(sentence)
+
+    assert len(sentence.get_embedding()) == 100
+    assert len(sentence.get_embedding()) == embeddings.embedding_length
+
+    sentence.clear_embeddings()
+
+    assert len(sentence.get_embedding()) == 0
     del embeddings
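
With kernels=((50, 2), (50, 3)) the document embedding length is 50 + 50 = 100, which is what the first assertion checks; glove and flair_embedding are presumably the token-embedding fixtures defined earlier in this test module.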
