@@ -688,3 +688,162 @@ def _add_embeddings_to_sentences(self, sentences: List[Sentence]):
     def embedding_length(self) -> int:
         """Returns the length of the embedding vector."""
         return self.model.get_sentence_embedding_dimension()
+
+
+class DocumentCNNEmbeddings(DocumentEmbeddings):
+    def __init__(
+        self,
+        embeddings: List[TokenEmbeddings],
+        kernels=((100, 3), (100, 4), (100, 5)),
+        reproject_words: bool = True,
+        reproject_words_dimension: int = None,
+        dropout: float = 0.5,
+        word_dropout: float = 0.0,
+        locked_dropout: float = 0.0,
+        fine_tune: bool = True,
+    ):
+        """The constructor takes a list of embeddings to be combined.
+        :param embeddings: a list of token embeddings
+        :param kernels: list of (number of kernels, kernel size) tuples
+        :param reproject_words: boolean value, indicating whether to reproject the token embeddings in a separate linear
+        layer before putting them into the CNN or not
+        :param reproject_words_dimension: output dimension of reprojecting token embeddings. If None the same output
+        dimension as before will be taken.
+        :param dropout: the dropout value to be used
+        :param word_dropout: the word dropout value to be used, if 0.0 word dropout is not used
+        :param locked_dropout: the locked dropout value to be used, if 0.0 locked dropout is not used
+        :param fine_tune: if False, the produced document embeddings are detached from the computation graph
+        """
716
+ super ().__init__ ()
717
+
718
+ self .embeddings : StackedEmbeddings = StackedEmbeddings (embeddings = embeddings )
719
+ self .length_of_all_token_embeddings : int = self .embeddings .embedding_length
720
+
721
+ self .kernels = kernels
722
+ self .reproject_words = reproject_words
723
+
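+        # static_embeddings is the inverse of fine_tune; when True, _add_embeddings_internal
+        # detaches the computed document embedding from the computation graph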
724
+ self .static_embeddings = False if fine_tune else True
725
+
726
+ self .embeddings_dimension : int = self .length_of_all_token_embeddings
727
+ if self .reproject_words and reproject_words_dimension is not None :
728
+ self .embeddings_dimension = reproject_words_dimension
729
+
730
+ self .word_reprojection_map = torch .nn .Linear (
731
+ self .length_of_all_token_embeddings , self .embeddings_dimension
732
+ )
733
+
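+        # the document embedding is the concatenation of one max-pooled feature map per kernel,
+        # so its length is the sum of the kernel counts (100 + 100 + 100 = 300 with the defaults)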
734
+ # CNN
735
+ self .__embedding_length : int = sum ([kernel_num for kernel_num , kernel_size in self .kernels ])
736
+ self .convs = torch .nn .ModuleList (
737
+ [
738
+ torch .nn .Conv1d (self .embeddings_dimension , kernel_num , kernel_size ) for kernel_num , kernel_size in self .kernels
739
+ ]
740
+ )
741
+ self .pool = torch .nn .AdaptiveMaxPool1d (1 )
742
+
743
+ self .name = "document_cnn"
744
+
745
+ # dropouts
746
+ self .dropout = torch .nn .Dropout (dropout ) if dropout > 0.0 else None
747
+ self .locked_dropout = (
748
+ LockedDropout (locked_dropout ) if locked_dropout > 0.0 else None
749
+ )
750
+ self .word_dropout = WordDropout (word_dropout ) if word_dropout > 0.0 else None
751
+
752
+ torch .nn .init .xavier_uniform_ (self .word_reprojection_map .weight )
753
+
754
+ self .to (flair .device )
755
+
756
+ self .eval ()
757
+
+    @property
+    def embedding_length(self) -> int:
+        return self.__embedding_length
+
+    def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
+        """Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update
+        only if embeddings are non-static."""
+
+        # TODO: remove in future versions
+        if not hasattr(self, "locked_dropout"):
+            self.locked_dropout = None
+        if not hasattr(self, "word_dropout"):
+            self.word_dropout = None
+
+        if type(sentences) is Sentence:
+            sentences = [sentences]
+
+        self.zero_grad()  # is it necessary?
+
+        # embed words in the sentence
+        self.embeddings.embed(sentences)
+
+        lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
+        longest_token_sequence_in_batch: int = max(lengths)
+
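+        # pad every sentence up to the longest one in the batch with slices of a single
+        # pre-allocated zero tensor, then reshape to (batch, max_len, embedding_length)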
+        pre_allocated_zero_tensor = torch.zeros(
+            self.embeddings.embedding_length * longest_token_sequence_in_batch,
+            dtype=torch.float,
+            device=flair.device,
+        )
+
+        all_embs: List[torch.Tensor] = list()
+        for sentence in sentences:
+            all_embs += [
+                emb for token in sentence for emb in token.get_each_embedding()
+            ]
+            nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)
+
+            if nb_padding_tokens > 0:
+                t = pre_allocated_zero_tensor[
+                    : self.embeddings.embedding_length * nb_padding_tokens
+                ]
+                all_embs.append(t)
+
+        sentence_tensor = torch.cat(all_embs).view(
+            [
+                len(sentences),
+                longest_token_sequence_in_batch,
+                self.embeddings.embedding_length,
+            ]
+        )
+
+        # before-CNN dropout
+        if self.dropout:
+            sentence_tensor = self.dropout(sentence_tensor)
+        if self.locked_dropout:
+            sentence_tensor = self.locked_dropout(sentence_tensor)
+        if self.word_dropout:
+            sentence_tensor = self.word_dropout(sentence_tensor)
+
+        # reproject if set
+        if self.reproject_words:
+            sentence_tensor = self.word_reprojection_map(sentence_tensor)
+
+        # push CNN
+        x = sentence_tensor
+        x = x.permute(0, 2, 1)
+
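+        # Conv1d expects (batch, channels, length), hence the permute above; each conv produces
+        # (batch, kernel_num, L_out), adaptive max pooling reduces L_out to 1, and concatenating
+        # along the channel dimension gives (batch, sum of kernel counts, 1)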
+        rep = [self.pool(torch.nn.functional.relu(conv(x))) for conv in self.convs]
+        outputs = torch.cat(rep, 1)
+
+        outputs = outputs.reshape(outputs.size(0), -1)
+
+        # after-CNN dropout
+        if self.dropout:
+            outputs = self.dropout(outputs)
+        if self.locked_dropout:
+            outputs = self.locked_dropout(outputs)
+
+        # extract embeddings from CNN
+        for sentence_no, length in enumerate(lengths):
+            embedding = outputs[sentence_no]
+
+            if self.static_embeddings:
+                embedding = embedding.detach()
+
+            sentence = sentences[sentence_no]
+            sentence.set_embedding(self.name, embedding)
+
+    def _apply(self, fn):
+        for child_module in self.children():
+            child_module._apply(fn)
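
For reference, a minimal usage sketch of the new class (assuming it is exported from flair.embeddings alongside the other document embeddings; the GloVe word embeddings are just an example choice, not part of this diff):

    from flair.data import Sentence
    from flair.embeddings import WordEmbeddings, DocumentCNNEmbeddings

    # any token-level embeddings can be stacked; GloVe is used here as an example
    glove = WordEmbeddings("glove")
    document_embeddings = DocumentCNNEmbeddings(
        [glove],
        kernels=((100, 3), (100, 4), (100, 5)),
        reproject_words=True,
        reproject_words_dimension=256,
    )

    sentence = Sentence("The grass is green .")
    document_embeddings.embed(sentence)

    # one fixed-size vector per sentence; its length is the sum of the kernel counts (here 300)
    print(sentence.get_embedding().size())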