Commit e5eef92
1. Fixed some minor issues in the automatic processing of conllu data
2. Updated the patch
3. Slightly adjusted the distillation loss
4. Added the tune import
AlongWY committed Jan 22, 2021
1 parent c47b3f4 commit e5eef92
Showing 6 changed files with 165 additions and 140 deletions.
9 changes: 6 additions & 3 deletions ltp/data/dataset/conllu.py
@@ -59,8 +59,11 @@ def build_vocabs(data_dir, train_file, dev_file=None, test_file=None, min_freq=5
if 'char' in name:
counter.update(itertools.chain(*values[row]))
elif 'deps' == name:
deps = [[label.split(':', maxsplit=1)[1] for label in dep.split('|')] for dep in values[row]]
counter.update(itertools.chain(*deps))
try:
deps = [[label.split(':', maxsplit=1)[1] for label in dep.split('|')] for dep in values[row]]
counter.update(itertools.chain(*deps))
except:
counter.update('_')
else:
counter.update(values[row])

@@ -158,7 +161,7 @@ def _generate_examples(self, filepath):
id, words, lemma, upos, xpos, feats, head, deprel, deps, misc = [list(value) for value in zip(*block)]
if self.config.deps:
deps = [[label.split(':', maxsplit=1) for label in dep.split('|')] for dep in deps]
deps = [[{'id': depid, 'head': int(label[0]), 'rel': label[1]} for label in dep] for depid, dep in
deps = [[{'id': depid, 'head': int(label[0]), 'rel': label[-1]} for label in dep] for depid, dep in
enumerate(deps)]
deps = list(itertools.chain(*deps))
if any([dep['head'] >= len(words) for dep in deps]):
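The guarded branch above covers CoNLL-U rows whose DEPS cell is the placeholder `_` (or otherwise lacks a `head:rel` pair), which would make `split(':', maxsplit=1)[1]` raise. A standalone sketch of that behaviour, not taken from the repository (names are illustrative, and the bare `except:` is narrowed to `IndexError` here):

```python
import itertools
from collections import Counter

def count_dep_labels(deps_column, counter=None):
    """Count dependency relation labels from a CoNLL-U DEPS column.

    deps_column: list of cell strings, e.g. ["2:nsubj|4:conj", "0:root"].
    Falls back to counting '_' when a cell has no 'head:rel' structure,
    mirroring the guarded branch in build_vocabs above.
    """
    counter = counter if counter is not None else Counter()
    try:
        deps = [[label.split(':', maxsplit=1)[1] for label in cell.split('|')]
                for cell in deps_column]
        counter.update(itertools.chain(*deps))
    except IndexError:  # placeholder '_' (or malformed cell) has no ':' to split on
        counter.update('_')
    return counter

print(count_dep_labels(["2:nsubj|4:conj", "0:root"]))  # Counter({'nsubj': 1, 'conj': 1, 'root': 1})
print(count_dep_labels(["_"]))                         # Counter({'_': 1})
```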
37 changes: 21 additions & 16 deletions ltp/multitask_distill.py
@@ -37,8 +37,8 @@ def kd_ce_loss(logits_S, logits_T, temperature=1):
beta_logits_T = logits_T / temperature
beta_logits_S = logits_S / temperature
p_T = F.softmax(beta_logits_T, dim=-1)
loss = -(p_T * F.log_softmax(beta_logits_S, dim=-1)).sum(dim=-1).mean()
return loss
loss = -(p_T * F.log_softmax(beta_logits_S, dim=-1))
return (temperature * temperature * loss).sum(dim=-1).mean()


def kd_mse_loss(logits_S, logits_T, temperature=1):
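Context for the `kd_ce_loss` change above: the soft cross-entropy terms are now multiplied by T² before reduction, the usual correction that keeps soft-target gradients on a comparable scale as the temperature grows. A minimal runnable sketch under that reading (plain PyTorch; the tensor-temperature `unsqueeze` handling is assumed from the neighbouring `kd_mse_loss` context, and all names are illustrative):

```python
import torch
import torch.nn.functional as F

def soft_ce_kd(logits_S, logits_T, temperature=1.0):
    """Temperature-scaled soft cross-entropy between teacher and student logits.

    The per-class terms are weighted by T**2 before reduction, so the
    soft-target gradient magnitude stays comparable when the temperature
    changes (T may be a scalar or a per-example tensor).
    """
    if isinstance(temperature, torch.Tensor) and temperature.dim() > 0:
        temperature = temperature.unsqueeze(-1)          # broadcast over the class dimension
    p_T = F.softmax(logits_T / temperature, dim=-1)      # softened teacher distribution
    log_p_S = F.log_softmax(logits_S / temperature, dim=-1)
    loss = -(p_T * log_p_S)                              # element-wise cross-entropy terms
    return (temperature * temperature * loss).sum(dim=-1).mean()

# Toy usage: 4 examples, 5 classes, one temperature per example
logits_T = torch.randn(4, 5)
logits_S = torch.randn(4, 5)
tau = torch.full((4,), 8.0)
print(soft_ce_kd(logits_S, logits_T, tau))
```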
@@ -53,11 +53,11 @@ def kd_mse_loss(logits_S, logits_T, temperature=1):
temperature = temperature.unsqueeze(-1)
beta_logits_T = logits_T / temperature
beta_logits_S = logits_S / temperature
loss = F.mse_loss(beta_logits_S, beta_logits_T)
return loss
loss = F.mse_loss(beta_logits_S, beta_logits_T, reduction='none')
return (temperature * temperature * loss).mean()


def flsw_temperature_scheduler_builder(beta, gamma, base_temperature=8, eps=1e-4, *args):
def flsw_temperature_scheduler_builder(beta=1, gamma=2, base_temperature=8, eps=1e-4, *args):
'''
adapted from arXiv:1911.07471
'''
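The `kd_mse_loss` change switches to `reduction='none'` so that a per-example temperature tensor (such as the one produced by the flsw scheduler, adapted from arXiv:1911.07471, whose defaults are now beta=1, gamma=2) can weight each squared error before the final mean. A sketch of that usage; the toy scheduler below only stands in for the real flsw formula and is not the repository's implementation:

```python
import torch
import torch.nn.functional as F

def mse_kd(logits_S, logits_T, temperature=1.0):
    """MSE distillation loss with optional per-example temperature weighting."""
    if isinstance(temperature, torch.Tensor) and temperature.dim() > 0:
        temperature = temperature.unsqueeze(-1)
    beta_S = logits_S / temperature
    beta_T = logits_T / temperature
    loss = F.mse_loss(beta_S, beta_T, reduction='none')  # keep per-element losses
    return (temperature * temperature * loss).mean()     # re-scale by T**2, then average

def toy_temperature_scheduler(logits_S, logits_T, base_temperature=8.0, beta=1.0, gamma=2.0):
    """Illustrative per-example temperature: hotter where student and teacher disagree more.
    (Shape-compatible stand-in for the flsw scheduler; not its actual formula.)"""
    with torch.no_grad():
        disagreement = (logits_S - logits_T).norm(dim=-1)
        w = disagreement.pow(gamma)
        return (base_temperature + beta * (w - w.mean())).clamp(min=1e-4)

logits_T = torch.randn(4, 5)
logits_S = torch.randn(4, 5)
tau = toy_temperature_scheduler(logits_S, logits_T)
print(mse_kd(logits_S, logits_T, tau))
```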
@@ -90,11 +90,16 @@ def distill_linear(batch, result, target, temperature_scheduler, model: Model =
transitions = torch.as_tensor(extra['transitions'], device=model.device)
end_transitions = torch.as_tensor(extra['end_transitions'], device=model.device)

kd_loss = kd_mse_loss(active_logits, active_target_logits)
temperature = temperature_scheduler(model.srl_classifier.crf.transitions, transitions)
kd_loss = kd_mse_loss(active_logits, active_target_logits, temperature)

crf_loss = kd_mse_loss(transitions, model.srl_classifier.crf.transitions) + \
kd_mse_loss(start_transitions, model.srl_classifier.crf.start_transitions) + \
kd_mse_loss(end_transitions, model.srl_classifier.crf.end_transitions)
transitions_temp = temperature_scheduler(model.srl_classifier.crf.transitions, transitions)
s_transitions_temp = temperature_scheduler(model.srl_classifier.crf.start_transitions, start_transitions)
e_transitions_temp = temperature_scheduler(model.srl_classifier.crf.end_transitions, end_transitions)

crf_loss = kd_mse_loss(transitions, model.srl_classifier.crf.transitions, transitions_temp) + \
kd_mse_loss(start_transitions, model.srl_classifier.crf.start_transitions, s_transitions_temp) + \
kd_mse_loss(end_transitions, model.srl_classifier.crf.end_transitions, e_transitions_temp)

return kd_loss + crf_loss
else:
@@ -176,13 +181,13 @@ def distill_matrix_crf(batch, result, target, temperature_scheduler, model: Mode
transitions = torch.as_tensor(extra['transitions'], device=model.device)
end_transitions = torch.as_tensor(extra['end_transitions'], device=model.device)

# transitions_temp = temperature_scheduler(model.srl_classifier.crf.transitions, transitions)
# s_transitions_temp = temperature_scheduler(model.srl_classifier.crf.start_transitions, start_transitions)
# e_transitions_temp = temperature_scheduler(model.srl_classifier.crf.end_transitions, end_transitions)
transitions_temp = temperature_scheduler(model.srl_classifier.crf.transitions, transitions)
s_transitions_temp = temperature_scheduler(model.srl_classifier.crf.start_transitions, start_transitions)
e_transitions_temp = temperature_scheduler(model.srl_classifier.crf.end_transitions, end_transitions)

crf_loss = kd_mse_loss(transitions, model.srl_classifier.crf.transitions) + \
kd_mse_loss(start_transitions, model.srl_classifier.crf.start_transitions) + \
kd_mse_loss(end_transitions, model.srl_classifier.crf.end_transitions)
crf_loss = kd_mse_loss(transitions, model.srl_classifier.crf.transitions, transitions_temp) + \
kd_mse_loss(start_transitions, model.srl_classifier.crf.start_transitions, s_transitions_temp) + \
kd_mse_loss(end_transitions, model.srl_classifier.crf.end_transitions, e_transitions_temp)
return kd_loss + crf_loss
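Both `distill_linear` and `distill_matrix_crf` now derive a separate temperature for each CRF parameter block (`transitions`, `start_transitions`, `end_transitions`) via `temperature_scheduler` and pass it to the MSE loss. A compact, self-contained sketch of that pattern; the constant scheduler, `num_tags`, and random tensors are placeholders, not the project's code:

```python
import torch
import torch.nn.functional as F

def mse_kd(student, teacher, temperature):
    """Temperature-weighted MSE, as in the kd_mse_loss sketch above."""
    if isinstance(temperature, torch.Tensor) and temperature.dim() > 0:
        temperature = temperature.unsqueeze(-1)
    loss = F.mse_loss(student / temperature, teacher / temperature, reduction='none')
    return (temperature * temperature * loss).mean()

def constant_scheduler(student, teacher, base_temperature=8.0):
    """Placeholder for temperature_scheduler: one scalar temperature per parameter block."""
    return torch.tensor(base_temperature)

num_tags = 4  # illustrative tag-set size
teacher = {'transitions': torch.randn(num_tags, num_tags),
           'start_transitions': torch.randn(num_tags),
           'end_transitions': torch.randn(num_tags)}
student = {name: torch.randn_like(value) for name, value in teacher.items()}

# One scheduled temperature per block, each feeding its own MSE term.
crf_loss = sum(
    mse_kd(student[name], teacher[name],
           constant_scheduler(student[name], teacher[name]))
    for name in teacher
)
print(crf_loss)
```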


@@ -368,7 +373,7 @@ def add_task_specific_args(parent_parser):
parser.add_argument('--gpus_per_trial', type=float, default=1.0)
parser.add_argument('--cpus_per_trial', type=float, default=5.0)
parser.add_argument('--distill_beta', type=float, default=1.0)
parser.add_argument('--distill_gamma', type=float, default=1.0)
parser.add_argument('--distill_gamma', type=float, default=2.0)
parser.add_argument('--temperature', type=float, default=8.0)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--num_samples', type=int, default=10)
2 changes: 1 addition & 1 deletion ltp/nn/relative_transformer.py
@@ -126,7 +126,7 @@ def _shift(self, BD):
zero_pad = BD.new_zeros(bsz, n_head, max_len, 1)
BD = torch.cat([BD, zero_pad], dim=-1).view(bsz, n_head, -1, max_len) # bsz x n_head x (2max_len+1) x max_len
BD = BD[:, :, :-1, :].view(bsz, n_head, max_len, -1) # bsz x n_head x 2max_len x max_len
_, BD = torch.chunk(BD, dim=-1, chunks=2)
BD = BD[:, :, :, :max_len]
return BD


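On the `_shift` change: `torch.chunk(..., chunks=2)` with the first chunk discarded kept the second half of the last dimension, whereas the new slice keeps the first `max_len` columns, so the two lines select different relative-position windows. A tiny numeric check (shapes chosen arbitrarily) that makes the difference visible without asserting which window is the intended one:

```python
import torch

# Rebuild the _shift trick on a tiny tensor to see which columns each version keeps.
bsz, n_head, max_len = 1, 1, 3
BD = torch.arange(max_len * 2 * max_len, dtype=torch.float).view(bsz, n_head, max_len, 2 * max_len)

zero_pad = BD.new_zeros(bsz, n_head, max_len, 1)
shifted = torch.cat([BD, zero_pad], dim=-1).view(bsz, n_head, -1, max_len)  # bsz x n_head x (2*max_len+1) x max_len
shifted = shifted[:, :, :-1, :].view(bsz, n_head, max_len, -1)              # bsz x n_head x max_len x 2*max_len

old = torch.chunk(shifted, dim=-1, chunks=2)[1]  # previous line: keep the second half
new = shifted[:, :, :, :max_len]                 # this commit: keep the first max_len columns
print(old.squeeze())
print(new.squeeze())
```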
5 changes: 5 additions & 0 deletions ltp/patchs/patch_4_1_3.py
@@ -17,6 +17,11 @@ def patch(ckpt):
setattr(model_config, 'ner_use_crf', False)
setattr(model_config, 'ner_crf_reduction', 'sum')

if ckpt['seg'][0].startswith('B'):
ckpt['seg'] = list(reversed(ckpt['seg']))
seg_classifier_weight = ckpt['model']['seg_classifier.weight']
ckpt['model']['seg_classifier.weight'] = seg_classifier_weight[[1, 0]]

for key, value in ckpt['model'].items():
key: str
if key.startswith('srl_classifier.mlp_rel_h'):
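The new patch block keeps the segmentation label vocabulary and the classifier's output rows in sync: when the stored order begins with 'B' it is reversed, and `seg_classifier.weight` is permuted with the same `[1, 0]` index so row k still scores label `seg[k]`. A minimal illustration of why the two must move together (the 2-label linear setup is hypothetical; a bias term, if present, would presumably need the same permutation):

```python
import torch

torch.manual_seed(0)
seg = ['B', 'I']                          # stored label order
weight = torch.randn(2, 8)                # seg_classifier.weight: one row per label
features = torch.randn(8)

scores_before = weight @ features         # scores_before[k] scores seg[k]

if seg[0].startswith('B'):
    seg = list(reversed(seg))             # now ['I', 'B']
    weight = weight[[1, 0]]               # permute rows to match the new order

scores_after = weight @ features
# Row k still scores label seg[k]: 'B' gets the same score under either ordering.
assert torch.allclose(scores_before[0], scores_after[seg.index('B')])
print(dict(zip(seg, scores_after.tolist())))
```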
9 changes: 7 additions & 2 deletions ltp/plugins/tune/__init__.py
@@ -2,5 +2,10 @@
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

from .tune_checkpoint_reporter import TuneReportCallback
from .tune_checkpoint_reporter import TuneReportCheckpointCallback
import warnings

try:
from .tune_checkpoint_reporter import TuneReportCallback
from .tune_checkpoint_reporter import TuneReportCheckpointCallback
except Exception as e:
warnings.warn("install ray[tune] to use tune model hyper")
243 changes: 125 additions & 118 deletions ltp/utils/common_train.py
@@ -12,14 +12,8 @@
from pytorch_lightning import Trainer, loggers
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from pytorch_lightning.utilities.cloud_io import load as pl_load

from ltp.utils.task_info import TaskInfo
from ltp.utils.method_builder import default_build_method
from ltp.plugins.tune import TuneReportCheckpointCallback

warnings.filterwarnings("ignore")

@@ -63,118 +57,131 @@ def common_train(args, model_class, task_info: TaskInfo, build_method=default_bu
trainer.test()


def tune_train_once(
config,
checkpoint_dir=None,
args: argparse.Namespace = None,
model_class: type = None,
build_method=None,
task_info: TaskInfo = None,
model_kwargs: dict = None,
resume: str = None,
**kwargs
):
if resume is None:
resume = 'all'
args_vars = vars(args)
args_vars.update(config)

pl.seed_everything(args.seed)
logger = [
loggers.CSVLogger(save_dir=tune.get_trial_dir(), name="", version="."),
loggers.TensorBoardLogger(save_dir=tune.get_trial_dir(), name="", version=".", default_hp_metric=False)
]
trainer_args = dict(
logger=logger,
progress_bar_refresh_rate=0,
callbacks=[
TuneReportCheckpointCallback(
metrics={f'tune_{task_info.metric_name}': f'val_{task_info.metric_name}'},
filename="tune.ckpt", on="validation_end"
)
try:
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from pytorch_lightning.utilities.cloud_io import load as pl_load
from ltp.plugins.tune import TuneReportCheckpointCallback


def tune_train_once(
config,
checkpoint_dir=None,
args: argparse.Namespace = None,
model_class: type = None,
build_method=None,
task_info: TaskInfo = None,
model_kwargs: dict = None,
resume: str = None,
**kwargs
):
if resume is None:
resume = 'all'
args_vars = vars(args)
args_vars.update(config)

pl.seed_everything(args.seed)
logger = [
loggers.CSVLogger(save_dir=tune.get_trial_dir(), name="", version="."),
loggers.TensorBoardLogger(save_dir=tune.get_trial_dir(), name="", version=".", default_hp_metric=False)
]
)
if checkpoint_dir and resume == 'all':
trainer_args['resume_from_checkpoint'] = os.path.join(checkpoint_dir, "tune.ckpt")

# fix slurm trainer
os.environ["SLURM_JOB_NAME"] = "bash"
model = model_class(args, **model_kwargs)
build_method(model, task_info)
trainer: Trainer = Trainer.from_argparse_args(args, **trainer_args)
if checkpoint_dir and resume == 'model':
ckpt = pl_load(os.path.join(checkpoint_dir, "tune.ckpt"), map_location=lambda storage, loc: storage)
model = model._load_model_state(ckpt)
trainer.current_epoch = ckpt["epoch"]
trainer.fit(model)


def tune_train(args, model_class, task_info: TaskInfo, build_method=default_build_method, model_kwargs: dict = None):
if model_kwargs is None:
model_kwargs = {}
args.data_dir = os.path.abspath(args.data_dir)
this_time = time.strftime("%m-%d_%H:%M:%S", time.localtime())

lr_quantity = 10 ** round(math.log(args.lr, 10))
config = {
"seed": tune.choice(list(range(10))),

# 3e-4 for Small, 1e-4 for Base, 5e-5 for Large
"lr": tune.uniform(lr_quantity, 5 * lr_quantity),

# -1 for disable, 0.8 for Base/Small, 0.9 for Large
"layerwise_lr_decay_power": tune.choice([0.8, 0.9]),

# lr scheduler
"lr_scheduler": tune.choice(['linear_schedule_with_warmup', 'polynomial_decay_schedule_with_warmup'])
}
if torch.cuda.is_available():
resources_per_trial = {"cpu": args.cpus_per_trial, "gpu": args.gpus_per_trial}
else:
resources_per_trial = {"cpu": args.cpus_per_trial}
print("resources_per_trial", resources_per_trial)

analysis = tune.run(
tune.with_parameters(
tune_train_once,
args=args,
task_info=task_info,
model_class=model_class,
build_method=build_method,
model_kwargs=model_kwargs,
resume='all'
),
mode="max",
config=config,
num_samples=args.num_samples,
metric=f'tune_{task_info.metric_name}',
name=f'{task_info.task_name}_{this_time}',
local_dir='lightning_logs',
progress_reporter=CLIReporter(
parameter_columns=list(config.keys()),
metric_columns=[
"loss", f'tune_{task_info.metric_name}', "training_iteration"
trainer_args = dict(
logger=logger,
progress_bar_refresh_rate=0,
callbacks=[
TuneReportCheckpointCallback(
metrics={f'tune_{task_info.metric_name}': f'val_{task_info.metric_name}'},
filename="tune.ckpt", on="validation_end"
)
]
),
resources_per_trial=resources_per_trial,
scheduler=ASHAScheduler(
max_t=args.max_epochs,
grace_period=args.min_epochs
),
queue_trials=True,
keep_checkpoints_num=3,
checkpoint_score_attr=f'tune_{task_info.metric_name}'
)
print("Best hyperparameters found were: ", analysis.best_config)
)
if checkpoint_dir and resume == 'all':
trainer_args['resume_from_checkpoint'] = os.path.join(checkpoint_dir, "tune.ckpt")

# fix slurm trainer
os.environ["SLURM_JOB_NAME"] = "bash"
model = model_class(args, **model_kwargs)
build_method(model, task_info)
trainer: Trainer = Trainer.from_argparse_args(args, **trainer_args)
if checkpoint_dir and resume == 'model':
ckpt = pl_load(os.path.join(checkpoint_dir, "tune.ckpt"), map_location=lambda storage, loc: storage)
model = model._load_model_state(ckpt)
trainer.current_epoch = ckpt["epoch"]
trainer.fit(model)


def tune_train(args, model_class, task_info: TaskInfo, build_method=default_build_method,
model_kwargs: dict = None):
if model_kwargs is None:
model_kwargs = {}
args.data_dir = os.path.abspath(args.data_dir)
this_time = time.strftime("%m-%d_%H:%M:%S", time.localtime())

lr_quantity = 10 ** round(math.log(args.lr, 10))
config = {
"seed": tune.choice(list(range(10))),

# 3e-4 for Small, 1e-4 for Base, 5e-5 for Large
"lr": tune.uniform(lr_quantity, 5 * lr_quantity),

# -1 for disable, 0.8 for Base/Small, 0.9 for Large
"layerwise_lr_decay_power": tune.choice([0.8, 0.9]),

# lr scheduler
"lr_scheduler": tune.choice(['linear_schedule_with_warmup', 'polynomial_decay_schedule_with_warmup'])
}
if torch.cuda.is_available():
resources_per_trial = {"cpu": args.cpus_per_trial, "gpu": args.gpus_per_trial}
else:
resources_per_trial = {"cpu": args.cpus_per_trial}
print("resources_per_trial", resources_per_trial)

analysis = tune.run(
tune.with_parameters(
tune_train_once,
args=args,
task_info=task_info,
model_class=model_class,
build_method=build_method,
model_kwargs=model_kwargs,
resume='all'
),
mode="max",
config=config,
num_samples=args.num_samples,
metric=f'tune_{task_info.metric_name}',
name=f'{task_info.task_name}_{this_time}',
local_dir='lightning_logs',
progress_reporter=CLIReporter(
parameter_columns=list(config.keys()),
metric_columns=[
"loss", f'tune_{task_info.metric_name}', "training_iteration"
]
),
resources_per_trial=resources_per_trial,
scheduler=ASHAScheduler(
max_t=args.max_epochs,
grace_period=args.min_epochs
),
queue_trials=True,
keep_checkpoints_num=3,
checkpoint_score_attr=f'tune_{task_info.metric_name}'
)
print("Best hyperparameters found were: ", analysis.best_config)

args_vars = vars(args)
args_vars.update(analysis.best_config)
model = model_class.load_from_checkpoint(
os.path.join(analysis.best_checkpoint, "tune.ckpt"), hparams=args, **model_kwargs
)
logger = loggers.TensorBoardLogger(
save_dir=analysis.best_trial.logdir, name="", version=".", default_hp_metric=False
)
trainer: Trainer = Trainer.from_argparse_args(args, logger=logger)
build_method(model, task_info)
trainer.test(model)
args_vars = vars(args)
args_vars.update(analysis.best_config)
model = model_class.load_from_checkpoint(
os.path.join(analysis.best_checkpoint, "tune.ckpt"), hparams=args, **model_kwargs
)
logger = loggers.TensorBoardLogger(
save_dir=analysis.best_trial.logdir, name="", version=".", default_hp_metric=False
)
trainer: Trainer = Trainer.from_argparse_args(args, logger=logger)
build_method(model, task_info)
trainer.test(model)

except Exception as e:
def tune_train(*args, **kwargs):
print("please install ray[tune]")
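The bulk of the `common_train.py` diff is re-indentation: the Ray Tune imports and both `tune_train_once`/`tune_train` definitions move inside a `try:` block, and a stub `tune_train` is installed when ray[tune] is unavailable, so importing the module no longer hard-requires Ray. A stripped-down sketch of that structure with the function bodies elided (the original catches `Exception` broadly and prints; `ImportError` and a warning are used here for clarity):

```python
import warnings

try:
    from ray import tune
    from ray.tune import CLIReporter
    from ray.tune.schedulers import ASHAScheduler

    def tune_train(args, model_class, task_info, build_method=None, model_kwargs=None):
        """Run a Ray Tune hyperparameter search (body elided in this sketch)."""
        # build the search-space config, run tune.run(tune.with_parameters(...),
        # scheduler=ASHAScheduler(...), progress_reporter=CLIReporter(...), ...),
        # then reload the best checkpoint and call trainer.test(model)
        ...

except ImportError:
    def tune_train(*args, **kwargs):
        warnings.warn("ray[tune] is not installed; hyperparameter tuning is unavailable")
```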
