From c8085144c276adc6f221118c6f48182f1ecda4d6 Mon Sep 17 00:00:00 2001
From: XinghaoWu
Date: Sat, 22 Feb 2025 15:43:50 +0800
Subject: [PATCH] Add model heterogeneity settings in text modality.

---
 system/flcore/trainmodel/models.py | 64 ++++++++++++++++++++++++++++++
 system/main.py                     | 38 +++++++++++++++---
 2 files changed, 96 insertions(+), 6 deletions(-)

diff --git a/system/flcore/trainmodel/models.py b/system/flcore/trainmodel/models.py
index fb80b8e..cbd59f0 100644
--- a/system/flcore/trainmodel/models.py
+++ b/system/flcore/trainmodel/models.py
@@ -571,6 +571,70 @@ def forward(self, x):
 
 
 # ====================================================================================================================
+class GRUNet(nn.Module):
+    def __init__(self, hidden_dim, num_layers=2, bidirectional=False, dropout=0.2,
+                 padding_idx=0, vocab_size=98635, num_classes=10):
+        super().__init__()
+
+        self.dropout = nn.Dropout(dropout)
+        self.embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx)
+        self.gru = nn.GRU(input_size=hidden_dim,
+                          hidden_size=hidden_dim,
+                          num_layers=num_layers,
+                          bidirectional=bidirectional,
+                          dropout=dropout,
+                          batch_first=True)
+        dims = hidden_dim * 2 if bidirectional else hidden_dim
+        self.fc = nn.Linear(dims, num_classes)
+
+    def forward(self, x):
+        if isinstance(x, list):
+            text, text_lengths = x
+        else:
+            text, text_lengths = x, [x.shape[1] for _ in range(x.shape[0])]
+
+        embedded = self.embedding(text)
+
+        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths, batch_first=True, enforce_sorted=False)
+        packed_output, hidden = self.gru(packed_embedded)
+
+        if isinstance(hidden, tuple):  # LSTM returns (hidden, cell); GRU returns only hidden
+            hidden = hidden[0]
+
+        if self.gru.bidirectional:
+            hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
+        else:
+            hidden = hidden[-1, :, :]
+
+        hidden = self.dropout(hidden)
+        output = self.fc(hidden)
+        output = F.log_softmax(output, dim=1)
+
+        return output
+
+
+# ====================================================================================================================
+
+class TextLogisticRegression(nn.Module):
+    def __init__(self, hidden_dim, vocab_size=98635, num_classes=10):
+        super(TextLogisticRegression, self).__init__()
+        self.embedding = nn.Embedding(vocab_size, hidden_dim)
+        self.fc = nn.Linear(hidden_dim, num_classes)
+
+    def forward(self, x):
+        if isinstance(x, list):
+            text, _ = x
+        else:
+            text = x
+
+        embedded = self.embedding(text)
+        avg_embedding = embedded.mean(dim=1)  # average token embeddings to get a sentence representation
+        output = self.fc(avg_embedding)
+        output = F.log_softmax(output, dim=1)
+
+        return output
+
+# ====================================================================================================================
 
 # class linear(Function):
 #     @staticmethod
diff --git a/system/main.py b/system/main.py
index 4d366f5..db15e0f 100644
--- a/system/main.py
+++ b/system/main.py
@@ -142,16 +142,42 @@ def run(args):
             'torchvision.models.vit_b_32(image_size=32, num_classes=args.num_classes)'
         ]
 
-    elif args.model_family == "NLP_all":
+    elif args.model_family == "HtFE-txt-all":
         args.models = [
             'fastText(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
             'LSTMNet(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
             'BiLSTM_TextClassification(input_size=args.vocab_size, hidden_size=args.feature_dim, output_size=args.num_classes, num_layers=1, embedding_dropout=0, lstm_dropout=0, attention_dropout=0, embedding_length=args.feature_dim)',
-            'TextCNN(hidden_dim=args.feature_dim, max_len=args.max_len, vocab_size=args.vocab_size, num_classes=args.num_classes)',
-            'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=8, nlayers=2, num_classes=args.num_classes, max_len=args.max_len)'
+            'TextCNN(hidden_dim=args.feature_dim, max_len=args.max_len, vocab_size=args.vocab_size, num_classes=args.num_classes)',
+            'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=8, nlayers=2, num_classes=args.num_classes, max_len=args.max_len)',
+            'TextLogisticRegression(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
+            'GRUNet(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)'
+        ]
+
+    elif args.model_family == "HtFE-txt-6":
+        args.models = [
+            'fastText(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
+            'LSTMNet(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
+            'BiLSTM_TextClassification(input_size=args.vocab_size, hidden_size=args.feature_dim, output_size=args.num_classes, num_layers=1, embedding_dropout=0, lstm_dropout=0, attention_dropout=0, embedding_length=args.feature_dim)',
+            'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=8, nlayers=2, num_classes=args.num_classes, max_len=args.max_len)',
+            'TextLogisticRegression(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
+            'GRUNet(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)'
+        ]
+
+    elif args.model_family == "HtFE-txt-2":
+        args.models = [
+            'fastText(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
+            'TextLogisticRegression(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)'
+        ]
+
+    elif args.model_family == "HtFE-txt-4":
+        args.models = [
+            'fastText(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
+            'TextLogisticRegression(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
+            'LSTMNet(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
+            'BiLSTM_TextClassification(input_size=args.vocab_size, hidden_size=args.feature_dim, output_size=args.num_classes, num_layers=1, embedding_dropout=0, lstm_dropout=0, attention_dropout=0, embedding_length=args.feature_dim)'
         ]
 
-    elif args.model_family == "NLP_Transformers-nhead=8":
+    elif args.model_family == "HtFE-txt-5-1":
         args.models = [
             'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=8, nlayers=1, num_classes=args.num_classes, max_len=args.max_len)',
             'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=8, nlayers=2, num_classes=args.num_classes, max_len=args.max_len)',
@@ -160,7 +186,7 @@ def run(args):
             'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=8, nlayers=16, num_classes=args.num_classes, max_len=args.max_len)',
         ]
 
-    elif args.model_family == "NLP_Transformers-nlayers=4":
+    elif args.model_family == "HtFE-txt-5-2":
         args.models = [
             'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=1, nlayers=4, num_classes=args.num_classes, max_len=args.max_len)',
             'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=2, nlayers=4, num_classes=args.num_classes, max_len=args.max_len)',
@@ -169,7 +195,7 @@ def run(args):
             'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=16, nlayers=4, num_classes=args.num_classes, max_len=args.max_len)',
         ]
 
-    elif args.model_family == "NLP_Transformers":
+    elif args.model_family == "HtFE-txt-5-3":
         args.models = [
             'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=1, nlayers=1, num_classes=args.num_classes, max_len=args.max_len)',
             'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=2, nlayers=2, num_classes=args.num_classes, max_len=args.max_len)',