Skip to content

Commit

Permalink
add commit
Browse files Browse the repository at this point in the history
  • Loading branch information
nmquang003 committed Nov 29, 2024
2 parents 392b364 + 2eb56f3 commit ea2c2f7
Show file tree
Hide file tree
Showing 18 changed files with 136 additions and 114 deletions.
Binary file added data/__pycache__/BaseData.cpython-310.pyc
Binary file not shown.
Binary file added data/__pycache__/FewRel.cpython-310.pyc
Binary file not shown.
Binary file added data/__pycache__/TACRED.cpython-310.pyc
Binary file not shown.
Binary file added data/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
175 changes: 98 additions & 77 deletions models/EoE.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,14 @@ def get_prompt_indices(self, prelogits, expert_id=0):

return indices, scores_over_tasks, class_indices_over_tasks

def get_expert_indices(self, prelogits, task_id=None):
if task_id == None:
task_id = self.num_tasks
logits = self.tii_head[task_id](prelogits) # [n, task_num]
def get_expert_indices(self, prelogits, task_idx=None):
if task_idx == None:
task_idx = self.num_tasks
logits = self.tii_head[task_idx](prelogits) # [n, task_num]
scores, indices = torch.max(logits, dim=1)
return scores, indices
return scores, indices // 4

def forward(self, input_ids, attention_mask=None, labels=None, oracle=False, **kwargs):

batch_size, _ = input_ids.shape
if attention_mask is None:
attention_mask = input_ids != 0
Expand Down Expand Up @@ -215,83 +214,105 @@ def forward(self, input_ids, attention_mask=None, labels=None, oracle=False, **k
if "extract_mode" in kwargs:
del kwargs["extract_mode"]
return hidden_states
if "use_tii_head" in kwargs and kwargs["use_tii_head"]:
hidden_states = self.feature_extractor(
input_ids=kwargs["input_ids_without_marker"],
indices=None,
use_origin=True,
**kwargs
)

indices = self.get_expert_indices(hidden_states, self.num_tasks)

all_score_over_task = []
all_score_over_class = []
all_logits = []
for e_id in range(-1, self.num_tasks + 1):
if e_id == -1:
indices = None
use_origin = True
kwargs.update({"extract_mode": "entity"})
elif e_id == 0:
indices = None
use_origin = False
else:
indices = [e_id] * batch_size
use_origin = False
hidden_states = self.feature_extractor(
input_ids=input_ids if e_id != -1 else kwargs["input_ids_without_marker"],
input_ids=input_ids,
indices=indices,
use_origin=use_origin,
use_origin=False,
**kwargs
)
if "extract_mode" in kwargs:
del kwargs["extract_mode"]
_, scores_over_tasks, scores_over_classes = self.get_prompt_indices(hidden_states, expert_id=e_id) # scores, class indices
scores_over_tasks = scores_over_tasks.transpose(-1, -2)
scores_over_classes = scores_over_classes.transpose(-1, -2)
if e_id != -1:
scores_over_tasks[:, :e_id] = float('inf') # no seen task
logits = self.classifier[e_id](hidden_states)[:, :self.class_per_task]
all_logits.append(logits)
all_score_over_task.append(scores_over_tasks)
all_score_over_class.append(scores_over_classes)
all_score_over_task = torch.stack(all_score_over_task, dim=1) # (batch, expert_num, task_num)
all_score_over_class = torch.stack(all_score_over_class, dim=1) # (batch, expert_num, task_num)
all_logits = torch.stack(all_logits, dim=1)
indices = []
# expert0_score_over_task = all_score_over_task[:, 0, :] # (batch, task_num)
expert_values, expert_indices = torch.topk(all_score_over_task, dim=-1, k=all_score_over_task.shape[-1],
largest=False)
expert_values = expert_values.tolist()
expert_indices = expert_indices.tolist()
for i in range(batch_size):
bert_indices = expert_indices[i][0]
task_indices = expert_indices[i][1]
if self.default_expert == "bert":
default_indices = copy.deepcopy(bert_indices)
else:
default_indices = copy.deepcopy(task_indices)
min_task = min(bert_indices[0], task_indices[0])
max_task = max(bert_indices[0], task_indices[0])
# valid_task_id = [min_task, max_task]
cur_min_expert = self.shift_expert_id(min_task)
if bert_indices[0] != task_indices[0] and cur_min_expert > 1:
cur_ans = []
for j in range(0, cur_min_expert + 1):
if j <= self.max_expert: # self.max_expert==1 --> default expert
for k in expert_indices[i][j]:
if k >= min_task:
cur_ans.append(k)
break
cur_count = Counter(cur_ans)
most_common_element = cur_count.most_common(1)
if most_common_element[0][1] == cur_ans.count(default_indices[0]):
indices.append(default_indices[0])

logits = self.classifier[self.num_tasks](hidden_states)
preds = logits.max(dim=-1)[1] + self.class_per_task * indices
expert_task_preds=None
expert_class_preds=None
else:
all_score_over_task = []
all_score_over_class = []
all_logits = []
for e_id in range(-1, self.num_tasks + 1):
if e_id == -1:
indices = None
use_origin = True
kwargs.update({"extract_mode": "entity"})
elif e_id == 0:
indices = None
use_origin = False
else:
indices.append(most_common_element[0][0])
else:
indices.append(default_indices[0])
# indices.append(expert_indices[i][1][0])
indices = torch.LongTensor(indices).to(self.device)
if oracle:
task_idx = kwargs["task_idx"]
indices = torch.LongTensor([task_idx] * batch_size).to(self.device)
idx = torch.arange(batch_size).to(self.device)
all_logits = all_logits[idx, indices]
preds = all_logits.max(dim=-1)[1] + self.class_per_task * indices
indices = indices.tolist() if isinstance(indices, torch.Tensor) else indices
indices = [e_id] * batch_size
use_origin = False
hidden_states = self.feature_extractor(
input_ids=input_ids if e_id != -1 else kwargs["input_ids_without_marker"],
indices=indices,
use_origin=use_origin,
**kwargs
)
if "extract_mode" in kwargs:
del kwargs["extract_mode"]
_, scores_over_tasks, scores_over_classes = self.get_prompt_indices(hidden_states, expert_id=e_id) # scores, class indices
scores_over_tasks = scores_over_tasks.transpose(-1, -2)
scores_over_classes = scores_over_classes.transpose(-1, -2)
if e_id != -1:
scores_over_tasks[:, :e_id] = float('inf') # no seen task
logits = self.classifier[e_id](hidden_states)[:, :self.class_per_task]
all_logits.append(logits)
all_score_over_task.append(scores_over_tasks)
all_score_over_class.append(scores_over_classes)
all_score_over_task = torch.stack(all_score_over_task, dim=1) # (batch, expert_num, task_num)
all_score_over_class = torch.stack(all_score_over_class, dim=1) # (batch, expert_num, task_num)
all_logits = torch.stack(all_logits, dim=1)
indices = []
# expert0_score_over_task = all_score_over_task[:, 0, :] # (batch, task_num)
expert_values, expert_indices = torch.topk(all_score_over_task, dim=-1, k=all_score_over_task.shape[-1],
largest=False)
expert_values = expert_values.tolist()
expert_indices = expert_indices.tolist()
for i in range(batch_size):
bert_indices = expert_indices[i][0]
task_indices = expert_indices[i][1]
if self.default_expert == "bert":
default_indices = copy.deepcopy(bert_indices)
else:
default_indices = copy.deepcopy(task_indices)
min_task = min(bert_indices[0], task_indices[0])
max_task = max(bert_indices[0], task_indices[0])
# valid_task_id = [min_task, max_task]
cur_min_expert = self.shift_expert_id(min_task)
if bert_indices[0] != task_indices[0] and cur_min_expert > 1:
cur_ans = []
for j in range(0, cur_min_expert + 1):
if j <= self.max_expert: # self.max_expert==1 --> default expert
for k in expert_indices[i][j]:
if k >= min_task:
cur_ans.append(k)
break
cur_count = Counter(cur_ans)
most_common_element = cur_count.most_common(1)
if most_common_element[0][1] == cur_ans.count(default_indices[0]):
indices.append(default_indices[0])
else:
indices.append(most_common_element[0][0])
else:
indices.append(default_indices[0])
# indices.append(expert_indices[i][1][0])
indices = torch.LongTensor(indices).to(self.device)
if oracle:
task_idx = kwargs["task_idx"]
indices = torch.LongTensor([task_idx] * batch_size).to(self.device)
idx = torch.arange(batch_size).to(self.device)
all_logits = all_logits[idx, indices]
preds = all_logits.max(dim=-1)[1] + self.class_per_task * indices
indices = indices.tolist() if isinstance(indices, torch.Tensor) else indices

return ExpertOutput(
preds=preds,
indices=indices,
Expand Down
Binary file added models/__pycache__/EoE.cpython-310.pyc
Binary file not shown.
Binary file added models/__pycache__/ExpertModel.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file added models/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
75 changes: 38 additions & 37 deletions trainers/EoETrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def run(self, data, model, tokenizer, label_order, seed=None, train=True):

self.statistic(model, train_dataset, default_data_collator)

# training tii
mean, cov, _, _ = self.get_mean_and_cov(model, train_dataset, default_data_collator, -1)
self.train_tii(model, mean, cov, num_sample=1000)
means = model.expert_distribution[0]['class_mean']
cov = model.expert_distribution[0]['accumulate_cov']
self.train_tii(model, means, cov, task_idx, num_sample=1000)

cur_test_data = data.filter(cur_labels, 'test')
history_test_data = data.filter(seen_labels, 'test')
Expand Down Expand Up @@ -199,7 +199,9 @@ def train(self, model, train_dataset, data_collator):

progress_bar.close()

def train_tii(self, model, means, cov, num_sample=1000):
def train_tii(self, model, means, cov, task_idx, num_sample=1000):
# print("-----1", means.shape)
# print("-----2", cov.shape)
# Dữ liệu đầu vào (mỗi mẫu là một vector)
all_samples = []
all_labels = []
Expand All @@ -210,7 +212,8 @@ def train_tii(self, model, means, cov, num_sample=1000):
for j in range(len(means)): # Task index
for k in range(len(means[0])): # Class index
mean = means[j][k].cuda() # Mean của lớp thứ k trong task thứ j

# print("-----3", mean.shape)
# print("-----4", cov_regularized.shape)
# Khởi tạo phân phối Gaussian đa biến
mvn = MultivariateNormal(mean, covariance_matrix=cov_regularized)

Expand All @@ -223,50 +226,47 @@ def train_tii(self, model, means, cov, num_sample=1000):
all_samples = torch.cat(all_samples, dim=0) # (total_samples, feature_dim)
all_labels = torch.cat(all_labels, dim=0) # (total_samples,)

logger.info("***** Running training tii *****")
logger.info(f"***** Running training tii[{task_idx}] *****")
logger.info(f" Num examples per each class = {num_sample}")
logger.info(f" Num Epochs = {self.args.num_train_epochs_tii}")
logger.info(f" Train batch size = {self.args.train_batch_size_tii}")

for task_idx in range(self.args.num_tasks):
# Chọn dữ liệu của các task từ 0 đến task_idx
task_samples = all_samples[: (task_idx + 1) * num_sample]
task_labels = all_labels[: (task_idx + 1) * num_sample]
dataset = TensorDataset(task_samples, task_labels)
dataloader = DataLoader(dataset, batch_size=self.args.train_batch_size_tii, shuffle=True)

tii_head = model.tii_head[task_idx]
max_steps = len(dataloader) * self.args.num_train_epochs_tii
progress_bar = tqdm(range(max_steps))
# Chọn dữ liệu của các task từ 0 đến task_idx
task_samples = all_samples[: (task_idx + 1) * num_sample]
task_labels = all_labels[: (task_idx + 1) * num_sample]
dataset = TensorDataset(task_samples, task_labels)
dataloader = DataLoader(dataset, batch_size=self.args.train_batch_size_tii, shuffle=True)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(tii_head.parameters(), lr=0.001)

for epoch in range(self.args.num_train_epochs_tii):
tii_head.train()
for inputs, labels in dataloader:
inputs, labels = inputs.cuda(), labels.cuda()
tii_head = model.tii_head[task_idx]
max_steps = len(dataloader) * self.args.num_train_epochs_tii
progress_bar = tqdm(range(max_steps))

# Zero gradients
optimizer.zero_grad()
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(tii_head.parameters(), lr=0.001)

# Forward pass
outputs = tii_head(inputs) # Tiến hành inference với tii_head
for epoch in range(self.args.num_train_epochs_tii):
tii_head.train()
for inputs, labels in dataloader:
inputs, labels = inputs.cuda(), labels.cuda()

# Tính toán loss và backpropagation
loss = criterion(outputs, labels)
loss.backward()
nn.utils.clip_grad_norm_(tii_head.parameters(), self.args.max_grad_norm)
optimizer.step()
# Zero gradients
optimizer.zero_grad()

progress_bar.update(1)
progress_bar.set_postfix({"Loss": loss.item()})
progress_bar.close()
# Forward pass
outputs = tii_head(inputs) # Tiến hành inference với tii_head

# Tính toán loss và backpropagation
loss = criterion(outputs, labels)
loss.backward()
nn.utils.clip_grad_norm_(tii_head.parameters(), self.args.max_grad_norm)
optimizer.step()

progress_bar.update(1)
progress_bar.set_postfix({"Loss": loss.item()})
progress_bar.close()

@torch.no_grad()
def eval(self, model, eval_dataset, data_collator, seen_labels, label2task_id, oracle=False):
def eval(self, model, eval_dataset, data_collator, seen_labels, label2task_id, oracle=False, use_tii_head=True):
eval_dataloader = DataLoader(
eval_dataset,
batch_size=self.args.eval_batch_size,
Expand Down Expand Up @@ -296,6 +296,7 @@ def eval(self, model, eval_dataset, data_collator, seen_labels, label2task_id, o
inputs = {k: v.to(self.args.device) for k, v in inputs.items()}
if oracle:
inputs.update({"oracle": True, "task_idx": self.task_idx})
inputs.update({"use_tii_head": use_tii_head})
outputs = model(**inputs)

hit_pred = outputs.indices
Expand All @@ -320,7 +321,7 @@ def eval(self, model, eval_dataset, data_collator, seen_labels, label2task_id, o
logger.info("Acc {}".format(acc))
logger.info("Hit Acc {}".format(hit_acc))

if not oracle:
if not oracle and not use_tii_head:
expert_task_preds = torch.cat(expert_task_preds, dim=0).tolist()
expert_class_preds = torch.cat(expert_class_preds, dim=0).tolist()
save_data = {
Expand Down
Binary file added trainers/__pycache__/BaseTrainer.cpython-310.pyc
Binary file not shown.
Binary file added trainers/__pycache__/EoETrainer.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file added trainers/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file added utils/__pycache__/DataCollator.cpython-310.pyc
Binary file not shown.
Binary file added utils/__pycache__/Distance.cpython-310.pyc
Binary file not shown.
Binary file added utils/__pycache__/__init__.cpython-310.pyc
Binary file not shown.

0 comments on commit ea2c2f7

Please sign in to comment.