Skip to content

Commit

Permalink
Add model heterogeneity settings in text modality.
Browse files Browse the repository at this point in the history
  • Loading branch information
XinghaoWu committed Feb 22, 2025
1 parent 3fb1596 commit c808514
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 6 deletions.
64 changes: 64 additions & 0 deletions system/flcore/trainmodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,70 @@ def forward(self, x):

# ====================================================================================================================

class GRUNet(nn.Module):
    """GRU-based text classifier over token-id sequences.

    Embeds token ids, runs a (optionally bidirectional) multi-layer GRU over
    the packed sequence, and classifies from the final hidden state.

    Args:
        hidden_dim: embedding size and GRU hidden size.
        num_layers: number of stacked GRU layers.
        bidirectional: if True, concatenate the last forward/backward states.
        dropout: dropout rate (between GRU layers and before the classifier).
        padding_idx: embedding index treated as padding (gradient frozen).
        vocab_size: embedding table size.
        num_classes: output dimension of the classifier.
    """

    def __init__(self, hidden_dim, num_layers=2, bidirectional=False, dropout=0.2,
                 padding_idx=0, vocab_size=98635, num_classes=10):
        super().__init__()

        self.dropout = nn.Dropout(dropout)
        self.embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx=padding_idx)
        self.gru = nn.GRU(input_size=hidden_dim,
                          hidden_size=hidden_dim,
                          num_layers=num_layers,
                          bidirectional=bidirectional,
                          dropout=dropout,
                          batch_first=True)
        dims = hidden_dim * 2 if bidirectional else hidden_dim
        self.fc = nn.Linear(dims, num_classes)

    def forward(self, x):
        """Return log-probabilities of shape (batch, num_classes).

        Accepts either a token-id tensor of shape (batch, seq_len), or a
        list ``[text, text_lengths]`` where ``text_lengths`` gives the true
        (unpadded) length of each sequence.
        """
        if isinstance(x, list):
            text, text_lengths = x
        else:
            # No lengths supplied: treat every row as full length.
            text, text_lengths = x, [x.shape[1]] * x.shape[0]

        embedded = self.embedding(text)

        # Pack so the GRU skips padded positions; enforce_sorted=False lets
        # callers pass sequences in any order.
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, text_lengths, batch_first=True, enforce_sorted=False)
        _, hidden = self.gru(packed_embedded)
        # NOTE: nn.GRU returns only the hidden state (unlike nn.LSTM, which
        # returns (hidden, cell)), so no tuple unwrapping is needed here.

        if self.gru.bidirectional:
            # Last layer's forward (-2) and backward (-1) final states.
            hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        else:
            hidden = hidden[-1, :, :]

        hidden = self.dropout(hidden)
        output = self.fc(hidden)
        output = F.log_softmax(output, dim=1)

        return output


# ====================================================================================================================

class TextLogisticRegression(nn.Module):
    """Bag-of-embeddings linear classifier.

    Averages the token embeddings of a sequence into a single vector and
    applies one linear layer, returning log-probabilities.

    Args:
        hidden_dim: embedding size.
        vocab_size: embedding table size.
        num_classes: output dimension of the classifier.
    """

    def __init__(self, hidden_dim, vocab_size=98635, num_classes=10):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return log-probabilities of shape (batch, num_classes).

        Accepts either a token-id tensor of shape (batch, seq_len) or a
        list ``[text, text_lengths]``; lengths are ignored here.
        """
        text = x[0] if isinstance(x, list) else x

        # Mean-pool the token embeddings to get one vector per sentence.
        pooled = self.embedding(text).mean(dim=1)
        return F.log_softmax(self.fc(pooled), dim=1)

# ====================================================================================================================

# class linear(Function):
# @staticmethod
Expand Down
38 changes: 32 additions & 6 deletions system/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,42 @@ def run(args):
'torchvision.models.vit_b_32(image_size=32, num_classes=args.num_classes)'
]

elif args.model_family == "NLP_all":
elif args.model_family == "HtFE-txt-all":
args.models = [
'fastText(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
'LSTMNet(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
'BiLSTM_TextClassification(input_size=args.vocab_size, hidden_size=args.feature_dim, output_size=args.num_classes, num_layers=1, embedding_dropout=0, lstm_dropout=0, attention_dropout=0, embedding_length=args.feature_dim)',
'TextCNN(hidden_dim=args.feature_dim, max_len=args.max_len, vocab_size=args.vocab_size, num_classes=args.num_classes)',
'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=8, nlayers=2, num_classes=args.num_classes, max_len=args.max_len)'
'TextCNN(hidden_dim=args.feature_dim, max_len=args.max_len, vocab_size=args.vocab_size, num_classes=args.num_classes)',
'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=8, nlayers=2, num_classes=args.num_classes, max_len=args.max_len)',
'TextLogisticRegression(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
'GRUNet(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)'
]

elif args.model_family == "HtFE-txt-6":
args.models = [
'fastText(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
'LSTMNet(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
'BiLSTM_TextClassification(input_size=args.vocab_size, hidden_size=args.feature_dim, output_size=args.num_classes, num_layers=1, embedding_dropout=0, lstm_dropout=0, attention_dropout=0, embedding_length=args.feature_dim)',
'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=8, nlayers=2, num_classes=args.num_classes, max_len=args.max_len)',
'TextLogisticRegression(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
'GRUNet(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)'
]

elif args.model_family == "HtFE-txt-2":
args.models = [
'fastText(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
'TextLogisticRegression(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)'
]

elif args.model_family == "HtFE-txt-4":
args.models = [
'fastText(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
'TextLogisticRegression(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
'LSTMNet(hidden_dim=args.feature_dim, vocab_size=args.vocab_size, num_classes=args.num_classes)',
'BiLSTM_TextClassification(input_size=args.vocab_size, hidden_size=args.feature_dim, output_size=args.num_classes, num_layers=1, embedding_dropout=0, lstm_dropout=0, attention_dropout=0, embedding_length=args.feature_dim)'
]

elif args.model_family == "NLP_Transformers-nhead=8":
elif args.model_family == "HtFE-txt-5-1":
args.models = [
'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=8, nlayers=1, num_classes=args.num_classes, max_len=args.max_len)',
'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=8, nlayers=2, num_classes=args.num_classes, max_len=args.max_len)',
Expand All @@ -160,7 +186,7 @@ def run(args):
'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=8, nlayers=16, num_classes=args.num_classes, max_len=args.max_len)',
]

elif args.model_family == "NLP_Transformers-nlayers=4":
elif args.model_family == "HtFE-txt-5-2":
args.models = [
'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=1, nlayers=4, num_classes=args.num_classes, max_len=args.max_len)',
'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=2, nlayers=4, num_classes=args.num_classes, max_len=args.max_len)',
Expand All @@ -169,7 +195,7 @@ def run(args):
'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=16, nlayers=4, num_classes=args.num_classes, max_len=args.max_len)',
]

elif args.model_family == "NLP_Transformers":
elif args.model_family == "HtFE-txt-5-3":
args.models = [
'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=1, nlayers=1, num_classes=args.num_classes, max_len=args.max_len)',
'TransformerModel(ntoken=args.vocab_size, d_model=args.feature_dim, nhead=2, nlayers=2, num_classes=args.num_classes, max_len=args.max_len)',
Expand Down

0 comments on commit c808514

Please sign in to comment.