fix: bad re's for collect_params().initialize(); add optimizer param
breezedeus committed May 19, 2020
1 parent fa272b1 commit ffe09fe
Showing 2 changed files with 59 additions and 41 deletions.
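For context on the first half of the commit message: Gluon's collect_params(select) matches parameter names against a regular expression, so the old glob-style selectors with a literal '*' and spaces around '|' could not match names like extra_conv0_weight, and the intended re-initialization of the non-backbone layers silently did nothing. A minimal sketch of the difference, using re directly with made-up parameter names (the real names inside PSENet may differ):

import re

param_names = ['extra_conv0_weight', 'decoder_stage1_weight', 'resnet_conv0_weight']

old_select = re.compile('extra_*_weight | decoder_*_weight')   # before this commit: glob-style
new_select = re.compile('extra_.*_weight|decoder_.*_weight')   # after this commit: real regex

print([n for n in param_names if old_select.match(n)])   # [] -- nothing selected
print([n for n in param_names if new_select.match(n)])   # ['extra_conv0_weight', 'decoder_stage1_weight']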
28 changes: 17 additions & 11 deletions cnstr/cli.py
@@ -21,10 +21,18 @@ def cli():
@click.option('-i', '--data_dir', type=str, help='root directory of the data')
@click.option('--pretrain_model_fp', type=str, default=None, help='path of the model used for initialization')
@click.option('--gpu', type=int, default=-1, help='number of GPUs to use; the default -1 means auto-detect')
@click.option('--batch_size', type=int, default=4, help='batch size for each device [Default: 4]')
@click.option(
"--optimizer",
type=str,
default='Adam',
help="optimizer for training [Default: Adam]",
)
@click.option(
'--batch_size', type=int, default=4, help='batch size for each device [Default: 4]'
)
@click.option('--epoch', type=int, default=50, help='train epochs [Default: 50]')
@click.option('--lr', type=float, default=0.001, help='learning rate [Default: 0.001]')
@click.option('--momentum', type=float, default=0.9, help='momentum [Default: 0.9]')
@click.option('--momentum', type=float, default=0.99, help='momentum [Default: 0.99]')
@click.option(
'--wd', type=float, default=5e-4, help='weight decay factor [Default: 5e-4]'
)
@@ -34,6 +42,7 @@ def train_model(
data_dir,
pretrain_model_fp,
gpu,
optimizer,
batch_size,
epoch,
lr,
@@ -49,30 +58,27 @@
train(
data_dir=data_dir,
pretrain_model=pretrain_model_fp,
ctx=devices,
optimizer=optimizer,
batch_size=batch_size,
epochs=epoch,
lr=lr,
momentum=momentum,
wd=wd,
verbose_step=log_step,
ckpt=output_dir,
ctx=devices,
)


@cli.command('evaluate', context_settings=_CONTEXT_SETTINGS)
@click.option('-i', '--data_dir', type=str, help='root directory of the data')
@click.option('--model_fp', type=str, default=None, help='model file path')
@click.option('--gpu', type=int, default=-1, help='number of GPUs to use; the default -1 means auto-detect')
@click.option('--batch_size', type=int, default=4, help='batch size for each device [Default: 4]')
@click.option(
'--batch_size', type=int, default=4, help='batch size for each device [Default: 4]'
)
@click.option('-o', '--output_dir', default='ckpt', help='directory where the model is saved')
def evaluate_model(
data_dir,
model_fp,
gpu,
batch_size,
output_dir,
):
def evaluate_model(data_dir, model_fp, gpu, batch_size, output_dir):
devices = gen_context(gpu)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
72 changes: 42 additions & 30 deletions cnstr/train.py
@@ -5,7 +5,7 @@
import mxnet as mx
from mxnet.gluon.data import DataLoader
from mxnet.gluon import Trainer
from mxnet import autograd, gluon, lr_scheduler as ls
from mxnet import autograd, lr_scheduler as ls
from tensorboardX import SummaryWriter

from .utils import to_cpu, split_and_load
@@ -20,6 +20,7 @@
def train(
data_dir,
pretrain_model,
optimizer,
epochs=50,
lr=0.001,
wd=5e-4,
@@ -30,35 +31,46 @@ def train(
ckpt='ckpt',
):
num_kernels = 3
icdar_loader = ICDAR(root_dir=data_dir, num_kernels=num_kernels - 1)
icdar_ds = ICDAR(root_dir=data_dir, num_kernels=num_kernels - 1)
if not isinstance(ctx, (list, tuple)):
ctx = [ctx]
batch_size = batch_size * len(ctx)
loader = DataLoader(icdar_loader, batch_size=batch_size, shuffle=True)
loader = DataLoader(icdar_ds, batch_size=batch_size, shuffle=True)
net = PSENet(num_kernels=num_kernels, ctx=ctx, pretrained=True)
# initial params
net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
net.collect_params("extra_*_weight | decoder_*_weight").initialize(
net.initialize(mx.init.Xavier(), ctx=ctx)
net.collect_params("extra_.*_weight|decoder_.*_weight").initialize(
mx.init.Xavier(), ctx=ctx, force_reinit=True
)
net.collect_params("extra_*_bias | decoder_*_bias").initialize(
net.collect_params("extra_.*_bias|decoder_.*_bias").initialize(
mx.init.Zero(), ctx=ctx, force_reinit=True
)
net.collect_params("!(resnet*)").setattr("lr_mult", 10)
net.collect_params("!(resnet*)").setattr('grad_req', 'null')
# net.collect_params("!(resnet*)").setattr("lr_mult", 10)
# net.collect_params("!(resnet*)").setattr('grad_req', 'null')
net.load_parameters(pretrain_model, ctx=ctx, allow_missing=True, ignore_extra=True)
# net.initialize(ctx=ctx)

# pse_loss = DiceLoss(lam=0.7, num_kernels=num_kernels)
pse_loss = DiceLoss_with_OHEM(lam=0.7, num_kernels=num_kernels, debug=False)

cos_shc = ls.PolyScheduler(
max_update=icdar_loader.length * epochs // batch_size, base_lr=lr
# lr_scheduler = ls.PolyScheduler(
# max_update=icdar_loader.length * epochs // batch_size, base_lr=lr
# )
max_update = len(icdar_ds) * epochs // batch_size
lr_scheduler = ls.MultiFactorScheduler(
base_lr=lr, step=[max_update // 3, max_update * 2 // 3], factor=0.1
)

optimizer_params = {
'learning_rate': lr,
'wd': wd,
'momentum': momentum,
'lr_scheduler': lr_scheduler,
}
if optimizer.lower() == 'adam':
optimizer_params.pop('momentum')

trainer = Trainer(
net.collect_params(),
'sgd',
{'learning_rate': lr, 'wd': wd, 'momentum': momentum, 'lr_scheduler': cos_shc},
net.collect_params(), optimizer=optimizer, optimizer_params=optimizer_params
)
summary_writer = SummaryWriter(ckpt)
for e in range(epochs):
@@ -67,19 +79,15 @@ def train(
num_batches = 0
for i, item in enumerate(loader):
item_ctxs = [split_and_load(field, ctx) for field in item]
# item_ctxs = split_and_load(item, ctx)
loss_list = []
for im, score_maps, kernels, training_masks, ori_img in zip(*item_ctxs):
# im, score_maps, kernels, training_masks, ori_img = item
# im = im.as_in_context(ctx)
# import pdb; pdb.set_trace()
score_maps = score_maps[:, ::4, ::4]
kernels = kernels[:, :, ::4, ::4]
for im, gt_text, gt_kernels, training_masks, ori_img in zip(*item_ctxs):
gt_text = gt_text[:, ::4, ::4]
gt_kernels = gt_kernels[:, :, ::4, ::4]
training_masks = training_masks[:, ::4, ::4]

with autograd.record():
kernels_pred = net(im)
loss = pse_loss(score_maps, kernels, kernels_pred, training_masks)
kernels_pred = net(im)  # index 0 is the prediction for the complete text
loss = pse_loss(gt_text, gt_kernels, kernels_pred, training_masks)
loss_list.append(loss)
mean_loss = []
for loss in loss_list:
@@ -88,22 +96,24 @@ def train(
mean_loss = np.mean(mean_loss)
trainer.step(batch_size)
if i % verbose_step == 0:
global_steps = icdar_loader.length * e + i * batch_size
global_steps = icdar_ds.length * e + i * batch_size
summary_writer.add_image(
'score_map', to_cpu(score_maps[0:1, :, :]), global_steps
'gt_text', to_cpu(gt_text[0, :, :]), global_steps
)
summary_writer.add_image(
'score_map_pred', to_cpu(kernels_pred[0:1, -1, :, :]), global_steps
'text_pred', to_cpu(kernels_pred[0, 0, :, :]), global_steps
)
summary_writer.add_image(
'kernel_map', to_cpu(kernels[0:1, 0, :, :]), global_steps
'gt_kernels[0]', to_cpu(gt_kernels[0, 0, :, :]), global_steps
)
summary_writer.add_image(
'kernel_map_pred', to_cpu(kernels_pred[0:1, 0, :, :]), global_steps
'kernels[0]_pred', to_cpu(kernels_pred[0, 1, :, :]), global_steps
)
summary_writer.add_scalar('loss', mean_loss, global_steps)
summary_writer.add_scalar(
'c_loss', mx.nd.mean(to_cpu(pse_loss.C_loss)).asscalar(), global_steps
'c_loss',
mx.nd.mean(to_cpu(pse_loss.C_loss)).asscalar(),
global_steps,
)
summary_writer.add_scalar(
'kernel_loss',
@@ -126,6 +136,8 @@ def train(
)
cumulative_loss += mean_loss
num_batches += 1
logger.info("Epoch {}, mean loss: {}\n".format(e, cumulative_loss / num_batches))
logger.info(
"Epoch {}, mean loss: {}\n".format(e, cumulative_loss / num_batches)
)
net.save_parameters(os.path.join(ckpt, 'model_{}.param'.format(e)))
summary_writer.close()
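The second half of the commit message boils down to the optimizer handling shown above. As a standalone toy sketch (placeholder layer, step values, and hyper-parameters, not the project's real training loop): MXNet's Adam optimizer takes beta1/beta2 instead of momentum, so momentum is only kept for SGD-style optimizers, while the MultiFactorScheduler rides along inside optimizer_params.

import mxnet as mx
from mxnet import lr_scheduler as ls
from mxnet.gluon import Trainer, nn

net = nn.Dense(10, in_units=5)          # placeholder network
net.initialize(mx.init.Xavier())

optimizer = 'Adam'                      # or 'sgd'
lr, wd, momentum = 0.001, 5e-4, 0.99    # placeholder hyper-parameters
scheduler = ls.MultiFactorScheduler(base_lr=lr, step=[100, 200], factor=0.1)

optimizer_params = {
    'learning_rate': lr,
    'wd': wd,
    'momentum': momentum,
    'lr_scheduler': scheduler,
}
if optimizer.lower() == 'adam':
    # Adam would raise a TypeError on an unexpected 'momentum' argument, so drop it
    optimizer_params.pop('momentum')

trainer = Trainer(net.collect_params(), optimizer=optimizer, optimizer_params=optimizer_params)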
