Commit 8f331be

1. wandb project name
2. sdp vocabs auto build
3. fix srl distill

AlongWY committed Jan 3, 2021
1 parent 0f3c956 commit 8f331be
Showing 11 changed files with 39 additions and 23 deletions.
2 changes: 1 addition & 1 deletion ltp/__init__.py
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*_
 # Author: Yunlong Feng <[email protected]>

-__version__ = '4.1.3.post1'
+__version__ = '4.1.4'

 from . import const
 from . import nn, utils
16 changes: 13 additions & 3 deletions ltp/data/dataset/conllu.py
@@ -40,7 +40,9 @@ def build_vocabs(data_dir, train_file, dev_file=None, test_file=None, min_freq=5
         'word': (1, Counter()), 'lemma': (2, Counter()), 'upos': (3, Counter()),
         'xpos': (4, Counter()), 'feats': (5, Counter()), 'deprel': (7, Counter()),
         # FOR CHAR FEATS
-        'word_char': (1, Counter())
+        'word_char': (1, Counter()),
+        # DEPS
+        'deps': (8, Counter())
     }

     if any([os.path.exists(os.path.join(data_dir, 'vocabs', f'{key}.txt')) for key in counters]):
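The integer in each pair indexes a CoNLL-U column (0 = ID, 1 = FORM, 2 = LEMMA, 3 = UPOS, 4 = XPOS, 5 = FEATS, 6 = HEAD, 7 = DEPREL, 8 = DEPS), so the new entry counts the DEPS column. A quick illustration against a hand-made token row (the sample row is illustrative, not from the LTP data):

row = '1\tDogs\tdog\tNOUN\tNNS\tNumber=Plur\t2\tnsubj\t2:nsubj\t_'.split('\t')
assert row[7] == 'nsubj'    # DEPREL -> counters['deprel'] reads index 7
assert row[8] == '2:nsubj'  # DEPS   -> the new counters['deps'] reads index 8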
@@ -56,6 +58,9 @@ def build_vocabs(data_dir, train_file, dev_file=None, test_file=None, min_freq=5
         for name, (row, counter) in counters.items():
             if 'char' in name:
                 counter.update(itertools.chain(*values[row]))
+            elif 'deps' == name:
+                deps = [[label.split(':', maxsplit=1)[1] for label in dep.split('|')] for dep in values[row]]
+                counter.update(itertools.chain(*deps))
             else:
                 counter.update(values[row])

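Each DEPS value packs one or more head:relation pairs joined by '|'; the new branch keeps only the relation after the first ':' (maxsplit=1 protects relations that themselves contain a colon, such as 'conj:and'). The same parse in isolation, with an illustrative sample value:

import itertools
from collections import Counter

dep = '2:nsubj|4:conj:and'  # illustrative DEPS value
labels = [label.split(':', maxsplit=1)[1] for label in dep.split('|')]
print(labels)  # ['nsubj', 'conj:and']

deps = [labels]  # per-token label lists, as in build_vocabs
counter = Counter()
counter.update(itertools.chain(*deps))
print(counter)  # Counter({'nsubj': 1, 'conj:and': 1})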
@@ -95,12 +100,17 @@ def _info(self):
             'xpos': self.config.xpos,
             'feats': self.config.feats,
             'deprel': self.config.deprel,
+            'deps': self.config.deps
         }

         for key in feats:
             if feats[key] is None:
                 feats[key] = os.path.join(self.config.data_dir, 'vocabs', f'{key}.txt')

+        deps_rel_feature = create_feature(feats['deps'])
+        if deps_rel_feature.num_classes > 1:
+            self.config.deps = feats['deps']
+
         return datasets.DatasetInfo(
             description=_DESCRIPTION,
             features=datasets.Features(
@@ -117,8 +127,8 @@ def _info(self):
                     {
                         'id': datasets.Value('int64'),
                         'head': datasets.Value("int64"),
-                        'rel': create_feature(self.config.deps)
-                    } if self.config.deps else create_feature(self.config.deps)
+                        'rel': deps_rel_feature
+                    } if deps_rel_feature.num_classes > 1 else create_feature(None)
                 ),
                 "misc": datasets.Sequence(datasets.Value("string")),
             }
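The net effect is the "sdp vocabs auto build" item from the commit message: whether the deps column is a labeled feature is now decided by the vocab that build_vocabs wrote to disk, not by a hand-set config value. For orientation, a minimal create_feature consistent with how the diff uses it — a hypothetical reconstruction, not the repository's actual implementation:

import datasets

def create_feature(file=None):
    # Hypothetical: a ClassLabel read from a vocab file when one is given,
    # a plain string feature otherwise.
    if file:
        return datasets.ClassLabel(names_file=file)
    return datasets.Value('string')

Under that reading, a corpus without DEPS annotations yields a degenerate vocab, num_classes never exceeds 1, and the 'rel' feature falls back to create_feature(None).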
8 changes: 6 additions & 2 deletions ltp/multitask.py
@@ -200,15 +200,18 @@ def build_ner_distill_dataset(args):
     model.eval()
     model.freeze()

-    dataset, metric = task_named_entity_recognition.build_dataset(model, args.ner_data_dir, task_info.task_name)
+    dataset, metric = task_named_entity_recognition.build_dataset(
+        model, args.ner_data_dir,
+        task_named_entity_recognition.task_info.task_name
+    )
     train_dataloader = torch.utils.data.DataLoader(
         dataset[datasets.Split.TRAIN],
         batch_size=args.batch_size,
         collate_fn=collate,
         num_workers=args.num_workers
     )

-    output = os.path.join(args.ner_data_dir, task_info.task_name, 'output.npz')
+    output = os.path.join(args.ner_data_dir, task_named_entity_recognition.task_info.task_name, 'output.npz')

     if torch.cuda.is_available():
         model.cuda()
@@ -246,6 +249,7 @@ def add_task_specific_args(parent_parser):
     parser.add_argument('--seed', type=int, default=19980524)
     parser.add_argument('--tune', action='store_true')
     parser.add_argument('--offline', action='store_true')
+    parser.add_argument('--project', type=str, default='ltp')
     parser.add_argument('--patience', type=int, default=5)
     parser.add_argument('--batch_size', type=int, default=8)
     parser.add_argument('--gpus_per_trial', type=float, default=1.0)
28 changes: 12 additions & 16 deletions ltp/multitask_distill.py
@@ -90,16 +90,11 @@ def distill_linear(batch, result, target, temperature_scheduler, model: Model =
         transitions = torch.as_tensor(extra['transitions'], device=model.device)
         end_transitions = torch.as_tensor(extra['end_transitions'], device=model.device)

-        temperature = temperature_scheduler(active_logits, active_target_logits)
-        kd_loss = kd_mse_loss(active_logits, active_target_logits, temperature)
-
-        transitions_temp = temperature_scheduler(model.srl_classifier.crf.transitions, transitions)
-        s_transitions_temp = temperature_scheduler(model.srl_classifier.crf.start_transitions, start_transitions)
-        e_transitions_temp = temperature_scheduler(model.srl_classifier.crf.end_transitions, end_transitions)
+        kd_loss = kd_mse_loss(active_logits, active_target_logits)

-        crf_loss = kd_mse_loss(transitions, model.srl_classifier.crf.transitions, transitions_temp) + \
-                   kd_mse_loss(start_transitions, model.srl_classifier.crf.start_transitions, s_transitions_temp) + \
-                   kd_mse_loss(end_transitions, model.srl_classifier.crf.end_transitions, e_transitions_temp)
+        crf_loss = kd_mse_loss(transitions, model.srl_classifier.crf.transitions) + \
+                   kd_mse_loss(start_transitions, model.srl_classifier.crf.start_transitions) + \
+                   kd_mse_loss(end_transitions, model.srl_classifier.crf.end_transitions)

         return kd_loss + crf_loss
     else:
@@ -168,7 +163,7 @@ def distill_matrix_crf(batch, result, target, temperature_scheduler, model: Mode
     index = logits_mask[:, 0]
     logits_mask = logits_mask[index]

-    s_rel, labels = result.arc_logits, result.labels
+    s_rel, labels = result.rel_logits, result.labels
     t_rel = target

     active_logits = s_rel[logits_mask]
@@ -181,13 +176,13 @@ def distill_matrix_crf(batch, result, target, temperature_scheduler, model: Mode
         transitions = torch.as_tensor(extra['transitions'], device=model.device)
         end_transitions = torch.as_tensor(extra['end_transitions'], device=model.device)

-        transitions_temp = temperature_scheduler(model.srl_classifier.crf.transitions, transitions)
-        s_transitions_temp = temperature_scheduler(model.srl_classifier.crf.start_transitions, start_transitions)
-        e_transitions_temp = temperature_scheduler(model.srl_classifier.crf.end_transitions, end_transitions)
+        # transitions_temp = temperature_scheduler(model.srl_classifier.crf.transitions, transitions)
+        # s_transitions_temp = temperature_scheduler(model.srl_classifier.crf.start_transitions, start_transitions)
+        # e_transitions_temp = temperature_scheduler(model.srl_classifier.crf.end_transitions, end_transitions)

-        crf_loss = kd_mse_loss(transitions, model.srl_classifier.crf.transitions, transitions_temp) + \
-                   kd_mse_loss(start_transitions, model.srl_classifier.crf.start_transitions, s_transitions_temp) + \
-                   kd_mse_loss(end_transitions, model.srl_classifier.crf.end_transitions, e_transitions_temp)
+        crf_loss = kd_mse_loss(transitions, model.srl_classifier.crf.transitions) + \
+                   kd_mse_loss(start_transitions, model.srl_classifier.crf.start_transitions) + \
+                   kd_mse_loss(end_transitions, model.srl_classifier.crf.end_transitions)
         return kd_loss + crf_loss

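Both distill functions now call kd_mse_loss without the scheduled temperature, which is the "fix srl distill" item — presumably because CRF transition matrices are small parameter tensors rather than per-token logits, so they are matched directly. A sketch of an MSE distillation loss with this call shape, assuming kd_mse_loss behaves like TextBrewer's function of the same name (both logit sets divided by a temperature that defaults to 1, then mean squared error):

import torch
import torch.nn.functional as F

def kd_mse_loss_sketch(logits_S, logits_T, temperature=1.0):
    # Soften both sides, then average the squared differences.
    return F.mse_loss(logits_S / temperature, logits_T / temperature)

# Dropping the temperature argument pulls the student CRF transitions
# straight toward the teacher's:
student_T = torch.randn(4, 4, requires_grad=True)  # student CRF transitions
teacher_T = torch.randn(4, 4)                      # teacher CRF transitions
loss = kd_mse_loss_sketch(teacher_T, student_T)    # temperature defaults to 1
loss.backward()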
@@ -359,6 +354,7 @@ def add_task_specific_args(parent_parser):
     parser.add_argument('--seed', type=int, default=19980524)
     parser.add_argument('--tune', action='store_true')
     parser.add_argument('--offline', action='store_true')
+    parser.add_argument('--project', type=str, default='ltp')

     parser.add_argument('--disable_seg', action='store_true')
     parser.add_argument('--disable_pos', action='store_true')
1 change: 1 addition & 0 deletions ltp/task_dependency_parsing.py
@@ -141,6 +141,7 @@ def add_task_specific_args(parent_parser):
     parser = ArgumentParser(parents=[parent_parser], add_help=False)
     parser.add_argument('--tune', action='store_true')
     parser.add_argument('--offline', action='store_true')
+    parser.add_argument('--project', type=str, default='ltp')
     parser.add_argument('--patience', type=int, default=5)
     parser.add_argument('--seed', type=int, default=19980524)
     parser.add_argument('--gpus_per_trial', type=float, default=1.0)
1 change: 1 addition & 0 deletions ltp/task_named_entity_recognition.py
@@ -129,6 +129,7 @@ def add_task_specific_args(parent_parser):
     parser = ArgumentParser(parents=[parent_parser], add_help=False)
     parser.add_argument('--tune', action='store_true')
     parser.add_argument('--offline', action='store_true')
+    parser.add_argument('--project', type=str, default='ltp')
     parser.add_argument('--patience', type=int, default=5)
     parser.add_argument('--seed', type=int, default=19980524)
     parser.add_argument('--gpus_per_trial', type=float, default=1.0)
1 change: 1 addition & 0 deletions ltp/task_part_of_speech.py
@@ -117,6 +117,7 @@ def add_task_specific_args(parent_parser):
     parser = ArgumentParser(parents=[parent_parser], add_help=False)
     parser.add_argument('--tune', action='store_true')
     parser.add_argument('--offline', action='store_true')
+    parser.add_argument('--project', type=str, default='ltp')
     parser.add_argument('--patience', type=int, default=5)
     parser.add_argument('--gpus_per_trial', type=float, default=1.0)
     parser.add_argument('--cpus_per_trial', type=float, default=1.0)
1 change: 1 addition & 0 deletions ltp/task_segmention.py
@@ -119,6 +119,7 @@ def add_task_specific_args(parent_parser):
     parser = ArgumentParser(parents=[parent_parser], add_help=False)
     parser.add_argument('--tune', action='store_true')
     parser.add_argument('--offline', action='store_true')
+    parser.add_argument('--project', type=str, default='ltp')
     parser.add_argument('--seed', type=int, default=19980524)
     parser.add_argument('--patience', type=int, default=5)
     parser.add_argument('--gpus_per_trial', type=float, default=1.0)
1 change: 1 addition & 0 deletions ltp/task_semantic_dependency_parsing.py
@@ -162,6 +162,7 @@ def add_task_specific_args(parent_parser):
     parser = ArgumentParser(parents=[parent_parser], add_help=False)
     parser.add_argument('--tune', action='store_true')
     parser.add_argument('--offline', action='store_true')
+    parser.add_argument('--project', type=str, default='ltp')
     parser.add_argument('--seed', type=int, default=19980524)
     parser.add_argument('--patience', type=int, default=5)
     parser.add_argument('--gpus_per_trial', type=float, default=1.0)
1 change: 1 addition & 0 deletions ltp/task_semantic_role_labeling.py
@@ -134,6 +134,7 @@ def add_task_specific_args(parent_parser):
     parser = ArgumentParser(parents=[parent_parser], add_help=False)
     parser.add_argument('--tune', action='store_true')
     parser.add_argument('--offline', action='store_true')
+    parser.add_argument('--project', type=str, default='ltp')
     parser.add_argument('--patience', type=int, default=5)
     parser.add_argument('--gpus_per_trial', type=float, default=1.0)
     parser.add_argument('--cpus_per_trial', type=float, default=1.0)
2 changes: 1 addition & 1 deletion ltp/utils/common_train.py
@@ -43,7 +43,7 @@ def common_train(args, model_class, task_info: TaskInfo, build_method=default_bu
     try:
         import wandb
         logger = loggers.WandbLogger(
-            project='ltp4',
+            project=args.project,
             save_dir='lightning_logs',
             name=f'{task_info.task_name}_{this_time}',
             offline=args.offline
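With every task script now exposing --project and common_train forwarding it, the wandb project name is configurable instead of hard-coded. A hypothetical end-to-end check of the flag (project and run names are illustrative):

from argparse import ArgumentParser
from pytorch_lightning import loggers

parser = ArgumentParser()
parser.add_argument('--project', type=str, default='ltp')
parser.add_argument('--offline', action='store_true')
args = parser.parse_args(['--project', 'my-ltp-runs', '--offline'])

logger = loggers.WandbLogger(
    project=args.project,   # previously hard-coded as 'ltp4'
    save_dir='lightning_logs',
    name='demo_run',        # the real code uses f'{task_name}_{this_time}'
    offline=args.offline,
)

Note that the new default also moves existing users from project 'ltp4' to 'ltp' unless they pass --project ltp4.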
