From 59c0e69b823cfc18ec8cab388df8cc7d4d75d7b4 Mon Sep 17 00:00:00 2001
From: fzou1
Date: Wed, 8 Nov 2017 14:47:28 +0800
Subject: [PATCH] Implement LARS (Layer-wise Adaptive Rate Scaling) from UC Berkeley

Change-Id: I5112a9de2d6357e303eea3501e202995dbdcb0bb
---
 include/caffe/sgd_solvers.hpp                |   2 +
 .../alexnet_bn_32nodes/solver.prototxt       |  47 +++
 .../alexnet_bn_32nodes/train_val2.prototxt   | 361 ++++++++++++++++++
 .../alexnet_bn_64nodes/solver.prototxt       |  47 +++
 .../alexnet_bn_64nodes/train_val2.prototxt   | 361 ++++++++++++++++++
 src/caffe/proto/caffe.proto                  |   3 +
 src/caffe/solvers/sgd_solver.cpp             |  46 ++-
 7 files changed, 863 insertions(+), 4 deletions(-)
 create mode 100644 models/intel_optimized_models/multinode/alexnet_bn_32nodes/solver.prototxt
 create mode 100644 models/intel_optimized_models/multinode/alexnet_bn_32nodes/train_val2.prototxt
 create mode 100644 models/intel_optimized_models/multinode/alexnet_bn_64nodes/solver.prototxt
 create mode 100644 models/intel_optimized_models/multinode/alexnet_bn_64nodes/train_val2.prototxt

diff --git a/include/caffe/sgd_solvers.hpp b/include/caffe/sgd_solvers.hpp
index 09f6ff26e..bef2be8ca 100644
--- a/include/caffe/sgd_solvers.hpp
+++ b/include/caffe/sgd_solvers.hpp
@@ -64,6 +64,8 @@ class SGDSolver : public Solver<Dtype> {
   void PreSolve();
   Dtype GetWarmUpLR(int cur_iter, int warmup_iter, Dtype warmup_start_lr);
   Dtype GetLearningRate();
+  Dtype GetLocalRate(int param_id) const;
+
   virtual void ApplyUpdate();
   virtual void ApplyUpdate(int param_id);
   virtual void Normalize(int param_id);
diff --git a/models/intel_optimized_models/multinode/alexnet_bn_32nodes/solver.prototxt b/models/intel_optimized_models/multinode/alexnet_bn_32nodes/solver.prototxt
new file mode 100644
index 000000000..b6c6690e1
--- /dev/null
+++ b/models/intel_optimized_models/multinode/alexnet_bn_32nodes/solver.prototxt
@@ -0,0 +1,47 @@
+net: "models/intel_optimized_models/multinode/alexnet_bn_32nodes/train_val2.prototxt"
+
+test_iter: 1000 # 1000 = 50000/50 (test batch_size 50)
+test_interval: 25
+test_initialization: false
+
+display: 10
+
+max_iter: 4000 # 100 epochs at global batch 32K
+
+#base_lr: 2 # B=1024
+#base_lr: 10 # B=8K
+base_lr: 33 # B=32K
+
+#min_lr: 0.005
+
+local_lr_auto: true
+local_gw_ratio: 0.001
+
+warmup_start_lr: 1
+warmup_iter: 400 # 10 epochs
+
+lr_policy: "poly"
+power: 2.
+
+momentum: 0.9
+weight_decay: 0.0005
+
+snapshot: 500000
+snapshot_prefix: "models/intel_optimized_models/multinode/alexnet_bn_32nodes/alexnet_bn_32nodes"
+#snapshot_after_train: false
+
+solver_mode: CPU
+
+# Train dataset size = 1,281,167
+# Test dataset size = 50,000
+
+# batch 64    --> epoch = 20,000
+# batch 96    --> epoch = 15,000
+# batch 128   --> epoch = 10,000
+# batch 256   --> epoch = 5,000
+# batch 512   --> epoch = 2,500
+# batch 1024  --> epoch = 1,250
+# batch 2048  --> epoch = 625
+# batch 4096  --> epoch = 312
+# batch 8192  --> epoch = 156
+# batch 16384 --> epoch = 78
diff --git a/models/intel_optimized_models/multinode/alexnet_bn_32nodes/train_val2.prototxt b/models/intel_optimized_models/multinode/alexnet_bn_32nodes/train_val2.prototxt
new file mode 100644
index 000000000..48046d17d
--- /dev/null
+++ b/models/intel_optimized_models/multinode/alexnet_bn_32nodes/train_val2.prototxt
@@ -0,0 +1,361 @@
+#------------------------------------------
+# AlexNet with BatchNorm instead of LRN
+#------------------------------------------
+name: "AlexNet_bn"
+
+#default_conv_algos_override: "1,1,1"
+#default_cudnn_math_override: 0
+
+layer {
+  name: "data"
+  type: "Data"
+  top: "data"
+  top: "label"
+  transform_param {
+    mean_file: "data/ilsvrc12/imagenet_mean.binaryproto"
+    mirror: true
+    crop_size: 227
+    scale: 0.00390625
+  }
+  data_param {
+    source: "examples/imagenet/ilsvrc12_train_lmdb"
+    batch_size: 1024
+    backend: LMDB
+    # cache: true
+    shuffle: true
+  }
+  include { phase: TRAIN }
+}
+layer {
+  name: "data"
+  type: "Data"
+  top: "data"
+  top: "label"
+  transform_param {
+    mean_file: "data/ilsvrc12/imagenet_mean.binaryproto"
+    mirror: false
+    crop_size: 227
+    scale: 0.00390625
+  }
+  data_param {
+    source: "examples/imagenet/ilsvrc12_val_lmdb"
+    batch_size: 50
+    backend: LMDB
+  }
+  include { phase: TEST }
+}
+layer {
+  name: "conv1"
+  type: "Convolution"
+  bottom: "data"
+  top: "conv1"
+  convolution_param {
+    num_output: 96
+    kernel_size: 11
+    stride: 4
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.01
+    }
+  }
+}
+layer {
+  name: "relu1"
+  type: "ReLU"
+  bottom: "conv1"
+  top: "conv1"
+}
+layer {
+  name: "pool1"
+  type: "Pooling"
+  bottom: "conv1"
+  top: "pool1"
+  pooling_param {
+    pool: MAX
+    kernel_size: 3
+    stride: 2
+  }
+}
+layer {
+  name: "bn2"
+  type: "BatchNorm"
+  bottom: "pool1"
+  top: "bn2"
+  batch_norm_param {
+    moving_average_fraction: 0.9
+    eps: 0.0001
+    # scale_bias: false
+  }
+}
+
+layer {
+  name: "conv2"
+  type: "Convolution"
+  bottom: "bn2"
+  top: "conv2"
+  convolution_param {
+    num_output: 256
+    pad: 2
+    kernel_size: 5
+    group: 2
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.01
+    }
+  }
+}
+layer {
+  name: "relu2"
+  type: "ReLU"
+  bottom: "conv2"
+  top: "conv2"
+}
+layer {
+  name: "pool2"
+  type: "Pooling"
+  bottom: "conv2"
+  top: "pool2"
+  pooling_param {
+    pool: MAX
+    kernel_size: 3
+    stride: 2
+  }
+}
+layer {
+  name: "bn3"
+  type: "BatchNorm"
+  bottom: "pool2"
+  top: "bn3"
+  batch_norm_param {
+    moving_average_fraction: 0.9
+    eps: 0.0001
+    # scale_bias: false
+  }
+}
+
+layer {
+  name: "conv3"
+  type: "Convolution"
+  bottom: "bn3"
+  top: "conv3"
+  convolution_param {
+    num_output: 384
+    pad: 1
+    kernel_size: 3
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.01
+    }
+  }
+}
+layer {
+  name: "relu3"
+  type: "ReLU"
+  bottom: "conv3"
+  top: "conv3"
+}
+layer {
+  name: "bn4"
+  type: "BatchNorm"
+  bottom: "conv3"
+  top: "bn4"
+  batch_norm_param {
+    moving_average_fraction: 0.9
+    eps: 0.0001
+    # scale_bias: false
+  }
+}
+layer {
+  name: "conv4"
+  type: "Convolution"
+  bottom: "bn4"
+  top: "conv4"
+  convolution_param {
+    num_output: 384
+    pad: 1
+    kernel_size: 3
+    group: 2
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.01
+    }
+  }
+}
+
+
+layer {
+  name: "relu4"
+  type: "ReLU"
+  bottom: "conv4"
+  top: "conv4"
+}
+layer {
+  name: "bn5"
+  type: "BatchNorm"
+  bottom: "conv4"
+  top: "bn5"
+  batch_norm_param {
+    moving_average_fraction: 0.9
+    eps: 0.0001
+    # scale_bias: false
+  }
+}
+layer {
+  name: "conv5"
+  type: "Convolution"
+  bottom: "bn5"
+  top: "conv5"
+  convolution_param {
+    num_output: 256
+    pad: 1
+    kernel_size: 3
+    group: 2
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.01
+    }
+  }
+}
+layer {
+  name: "relu5"
+  type: "ReLU"
+  bottom: "conv5"
+  top: "conv5"
+}
+layer {
+  name: "pool5"
+  type: "Pooling"
+  bottom: "conv5"
+  top: "pool5"
+  pooling_param {
+    pool: MAX
+    kernel_size: 3
+    stride: 2
+  }
+}
+layer {
+  name: "fc6"
+  type: "InnerProduct"
+  bottom: "pool5"
+  top: "fc6"
+  inner_product_param {
+    num_output: 4096
+    weight_filler {
+      type: "gaussian"
+      std: 0.005
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.1
+    }
+  }
+}
+layer {
+  name: "relu6"
+  type: "ReLU"
+  bottom: "fc6"
+  top: "fc6"
+}
+layer {
+  name: "drop6"
+  type: "Dropout"
+  bottom: "fc6"
+  top: "fc6"
+  dropout_param {
+    dropout_ratio: 0.5
+  }
+}
+layer {
+  name: "fc7"
+  type: "InnerProduct"
+  bottom: "fc6"
+  top: "fc7"
+  inner_product_param {
+    num_output: 4096
+    weight_filler {
+      type: "gaussian"
+      std: 0.005
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.1
+    }
+  }
+}
+layer {
+  name: "relu7"
+  type: "ReLU"
+  bottom: "fc7"
+  top: "fc7"
+}
+layer {
+  name: "drop7"
+  type: "Dropout"
+  bottom: "fc7"
+  top: "fc7"
+  dropout_param {
+    dropout_ratio: 0.5
+  }
+}
+layer {
+  name: "fc8"
+  type: "InnerProduct"
+  bottom: "fc7"
+  top: "fc8"
+  inner_product_param {
+    num_output: 1000
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.01
+    }
+  }
+}
+layer {
+  name: "loss"
+  type: "SoftmaxWithLoss"
+  bottom: "fc8"
+  bottom: "label"
+  top: "loss"
+}
+layer {
+  name: "top-1"
+  type: "Accuracy"
+  bottom: "fc8"
+  bottom: "label"
+  top: "accuracy/top-1"
+  accuracy_param { top_k: 1 }
+#  include { phase: TEST }
+}
+layer {
+  name: "top-5"
+  type: "Accuracy"
+  bottom: "fc8"
+  bottom: "label"
+  top: "accuracy/top-5"
+  accuracy_param { top_k: 5 }
+  include { phase: TEST }
+}
diff --git a/models/intel_optimized_models/multinode/alexnet_bn_64nodes/solver.prototxt b/models/intel_optimized_models/multinode/alexnet_bn_64nodes/solver.prototxt
new file mode 100644
index 000000000..70506cb0d
--- /dev/null
+++ b/models/intel_optimized_models/multinode/alexnet_bn_64nodes/solver.prototxt
@@ -0,0 +1,47 @@
+net: "models/intel_optimized_models/multinode/alexnet_bn_64nodes/train_val2.prototxt"
+
+test_iter: 196 # 196 = 50000/256 (test batch_size 256)
+test_interval: 40
+test_initialization: false
+
+display: 10
+
+max_iter: 4000 # 100 epochs at global batch 32K
+
+#base_lr: 2 # B=1024
+#base_lr: 10 # B=8K
+base_lr: 32 # B=32K
+
+#min_lr: 0.005
+
+local_lr_auto: true
+local_gw_ratio: 0.001
+
+warmup_start_lr: 1
+warmup_iter: 400 # 10 epochs
+
+lr_policy: "poly"
+power: 2.
+
+momentum: 0.9
+weight_decay: 0.0005
+
+snapshot: 500000
+snapshot_prefix: "models/intel_optimized_models/multinode/alexnet_bn_64nodes/alexnet_bn_64nodes"
+#snapshot_after_train: false
+
+solver_mode: CPU
+
+# Train dataset size = 1,281,167
+# Test dataset size = 50,000
+
+# batch 64    --> epoch = 20,000
+# batch 96    --> epoch = 15,000
+# batch 128   --> epoch = 10,000
+# batch 256   --> epoch = 5,000
+# batch 512   --> epoch = 2,500
+# batch 1024  --> epoch = 1,250
+# batch 2048  --> epoch = 625
+# batch 4096  --> epoch = 312
+# batch 8192  --> epoch = 156
+# batch 16384 --> epoch = 78
diff --git a/models/intel_optimized_models/multinode/alexnet_bn_64nodes/train_val2.prototxt b/models/intel_optimized_models/multinode/alexnet_bn_64nodes/train_val2.prototxt
new file mode 100644
index 000000000..f02068fcc
--- /dev/null
+++ b/models/intel_optimized_models/multinode/alexnet_bn_64nodes/train_val2.prototxt
@@ -0,0 +1,361 @@
+#------------------------------------------
+# AlexNet with BatchNorm instead of LRN
+#------------------------------------------
+name: "AlexNet_bn"
+
+#default_conv_algos_override: "1,1,1"
+#default_cudnn_math_override: 0
+
+layer {
+  name: "data"
+  type: "Data"
+  top: "data"
+  top: "label"
+  transform_param {
+    mean_file: "data/ilsvrc12/imagenet_mean.binaryproto"
+    mirror: true
+    crop_size: 227
+    scale: 0.00390625
+  }
+  data_param {
+    source: "examples/imagenet/ilsvrc12_train_lmdb"
+    batch_size: 512
+    backend: LMDB
+    # cache: true
+    shuffle: true
+  }
+  include { phase: TRAIN }
+}
+layer {
+  name: "data"
+  type: "Data"
+  top: "data"
+  top: "label"
+  transform_param {
+    mean_file: "data/ilsvrc12/imagenet_mean.binaryproto"
+    mirror: false
+    crop_size: 227
+    scale: 0.00390625
+  }
+  data_param {
+    source: "examples/imagenet/ilsvrc12_val_lmdb"
+    batch_size: 256
+    backend: LMDB
+  }
+  include { phase: TEST }
+}
+layer {
+  name: "conv1"
+  type: "Convolution"
+  bottom: "data"
+  top: "conv1"
+  convolution_param {
+    num_output: 96
+    kernel_size: 11
+    stride: 4
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.01
+    }
+  }
+}
+layer {
+  name: "relu1"
+  type: "ReLU"
+  bottom: "conv1"
+  top: "conv1"
+}
+layer {
+  name: "pool1"
+  type: "Pooling"
+  bottom: "conv1"
+  top: "pool1"
+  pooling_param {
+    pool: MAX
+    kernel_size: 3
+    stride: 2
+  }
+}
+layer {
+  name: "bn2"
+  type: "BatchNorm"
+  bottom: "pool1"
+  top: "bn2"
+  batch_norm_param {
+    moving_average_fraction: 0.9
+    eps: 0.0001
+    # scale_bias: false
+  }
+}
+
+layer {
+  name: "conv2"
+  type: "Convolution"
+  bottom: "bn2"
+  top: "conv2"
+  convolution_param {
+    num_output: 256
+    pad: 2
+    kernel_size: 5
+    group: 2
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.01
+    }
+  }
+}
+layer {
+  name: "relu2"
+  type: "ReLU"
+  bottom: "conv2"
+  top: "conv2"
+}
+layer {
+  name: "pool2"
+  type: "Pooling"
+  bottom: "conv2"
+  top: "pool2"
+  pooling_param {
+    pool: MAX
+    kernel_size: 3
+    stride: 2
+  }
+}
+layer {
+  name: "bn3"
+  type: "BatchNorm"
+  bottom: "pool2"
+  top: "bn3"
+  batch_norm_param {
+    moving_average_fraction: 0.9
+    eps: 0.0001
+    # scale_bias: false
+  }
+}
+
+layer {
+  name: "conv3"
+  type: "Convolution"
+  bottom: "bn3"
+  top: "conv3"
+  convolution_param {
+    num_output: 384
+    pad: 1
+    kernel_size: 3
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.01
+    }
+  }
+}
+layer {
+  name: "relu3"
+  type: "ReLU"
+  bottom: "conv3"
+  top: "conv3"
+}
+layer {
+  name: "bn4"
+  type: "BatchNorm"
+  bottom: "conv3"
+  top: "bn4"
+  batch_norm_param {
+    moving_average_fraction: 0.9
+    eps: 0.0001
+    # scale_bias: false
+  }
+}
+layer {
+  name: "conv4"
+  type: "Convolution"
+  bottom: "bn4"
+  top: "conv4"
+  convolution_param {
+    num_output: 384
+    pad: 1
+    kernel_size: 3
+    group: 2
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.01
+    }
+  }
+}
+
+
+layer {
+  name: "relu4"
+  type: "ReLU"
+  bottom: "conv4"
+  top: "conv4"
+}
+layer {
+  name: "bn5"
+  type: "BatchNorm"
+  bottom: "conv4"
+  top: "bn5"
+  batch_norm_param {
+    moving_average_fraction: 0.9
+    eps: 0.0001
+    # scale_bias: false
+  }
+}
+layer {
+  name: "conv5"
+  type: "Convolution"
+  bottom: "bn5"
+  top: "conv5"
+  convolution_param {
+    num_output: 256
+    pad: 1
+    kernel_size: 3
+    group: 2
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.01
+    }
+  }
+}
+layer {
+  name: "relu5"
+  type: "ReLU"
+  bottom: "conv5"
+  top: "conv5"
+}
+layer {
+  name: "pool5"
+  type: "Pooling"
+  bottom: "conv5"
+  top: "pool5"
+  pooling_param {
+    pool: MAX
+    kernel_size: 3
+    stride: 2
+  }
+}
+layer {
+  name: "fc6"
+  type: "InnerProduct"
+  bottom: "pool5"
+  top: "fc6"
+  inner_product_param {
+    num_output: 4096
+    weight_filler {
+      type: "gaussian"
+      std: 0.005
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.1
+    }
+  }
+}
+layer {
+  name: "relu6"
+  type: "ReLU"
+  bottom: "fc6"
+  top: "fc6"
+}
+layer {
+  name: "drop6"
+  type: "Dropout"
+  bottom: "fc6"
+  top: "fc6"
+  dropout_param {
+    dropout_ratio: 0.5
+  }
+}
+layer {
+  name: "fc7"
+  type: "InnerProduct"
+  bottom: "fc6"
+  top: "fc7"
+  inner_product_param {
+    num_output: 4096
+    weight_filler {
+      type: "gaussian"
+      std: 0.005
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.1
+    }
+  }
+}
+layer {
+  name: "relu7"
+  type: "ReLU"
+  bottom: "fc7"
+  top: "fc7"
+}
+layer {
+  name: "drop7"
+  type: "Dropout"
+  bottom: "fc7"
+  top: "fc7"
+  dropout_param {
+    dropout_ratio: 0.5
+  }
+}
+layer {
+  name: "fc8"
+  type: "InnerProduct"
+  bottom: "fc7"
+  top: "fc8"
+  inner_product_param {
+    num_output: 1000
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0.01
+    }
+  }
+}
+layer {
+  name: "loss"
+  type: "SoftmaxWithLoss"
+  bottom: "fc8"
+  bottom: "label"
+  top: "loss"
+}
+layer {
+  name: "top-1"
+  type: "Accuracy"
+  bottom: "fc8"
+  bottom: "label"
+  top: "accuracy/top-1"
+  accuracy_param { top_k: 1 }
+#  include { phase: TEST }
+}
+layer {
+  name: "top-5"
+  type: "Accuracy"
+  bottom: "fc8"
+  bottom: "label"
+  top: "accuracy/top-5"
+  accuracy_param { top_k: 5 }
+  include { phase: TEST }
+}
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 1d4608dec..a1f6e4d74 100755
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -415,6 +415,9 @@ message SolverParameter {
   optional string engine = 47 [default = ""];
   optional int32 warmup_iter = 48 [default = 0];
   optional float warmup_start_lr = 49 [default = 0];
+
+  optional bool local_lr_auto = 50 [default = false];
+  optional float local_gw_ratio = 51 [default = 0.001];
 }
 
 // A message that stores the solver snapshots
diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp
index 4b643c495..4142fef35 100755
--- a/src/caffe/solvers/sgd_solver.cpp
+++ b/src/caffe/solvers/sgd_solver.cpp
@@ -336,9 +336,8 @@ void SGDSolver<Dtype>::SGDFusion(int param_id, Dtype rate) {
   Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
 
   //ComputeUpdateValue initialization
-  const vector<float>& net_params_lr = this->net_->params_lr();
   Dtype momentum = this->param_.momentum();
-  Dtype local_rate = rate * net_params_lr[param_id];
+  Dtype local_rate = rate * GetLocalRate(param_id);
 
   //#pragma endregion
   //#pragma region 2. Common condition judgement
@@ -579,9 +578,8 @@ void sgd_update_gpu(int N, Dtype* g, Dtype* h, Dtype momentum,
 template <typename Dtype>
 void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
-  const vector<float>& net_params_lr = this->net_->params_lr();
   Dtype momentum = this->param_.momentum();
-  Dtype local_rate = rate * net_params_lr[param_id];
+  Dtype local_rate = rate * GetLocalRate(param_id);
 
   if (this->param_.warmup_iter() > 0 &&
       this->iter_ < this->param_.warmup_iter()) {
@@ -636,6 +634,46 @@ void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   }
 }
 
+//
+// LARS (Layer-wise Adaptive Rate Scaling) was developed by Yang You, Igor Gitman and Boris Ginsburg at UC Berkeley.
+// Please refer to the papers below:
+// Scaling SGD Batch Size to 32K for ImageNet Training (https://www2.eecs.berkeley.edu/Pubs/TechRpts/2017/EECS-2017-149.html).
+// Large Batch Training of Convolutional Networks (https://arxiv.org/abs/1708.03888).
+template <typename Dtype>
+Dtype SGDSolver<Dtype>::GetLocalRate(int param_id) const {
+  const vector<float>& net_params_lr = this->net_->params_lr();
+  float local_lr = net_params_lr[param_id];
+
+  if (this->param_.local_lr_auto()) {
+    Blob<Dtype>* param = this->net_->learnable_params()[param_id];
+    const float w_norm = std::sqrt(param->sumsq_data());
+    const float wgrad_norm = std::sqrt(param->sumsq_diff());
+    const float gw_ratio = this->param_.local_gw_ratio();
+    float rate = 1.F;
+
+    float weight_decay = this->param_.weight_decay();
+    if (w_norm > 0.F && wgrad_norm > 0.F) {
+      rate = gw_ratio * w_norm / (wgrad_norm + weight_decay * w_norm);
+    }
+    if (local_lr > 0.F) {
+      local_lr = rate;
+    }
+
+#ifdef DEBUG
+    if (Caffe::root_solver()
+        && this->param_.display()
+        && (this->iter_ % this->param_.display() == 0)) {
+      const int layer_id = this->net_->param_layer_indices(param_id).first;
+      const string& layer_name = this->net_->layer_names()[layer_id];
+      const int blob_id = this->net_->param_layer_indices(param_id).second;
+      LOG(INFO) << layer_name << "." << blob_id << " lr=" << local_lr
+          << ".\t w=" << w_norm << "\t dw=" << wgrad_norm;
+    }
+#endif
+  }
+  return local_lr;
+}
+
 template <typename Dtype>
 void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
   switch (this->param_.snapshot_format()) {
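
Note: below is a minimal standalone sketch of the local-rate rule that GetLocalRate() in this patch implements. It is not part of the patch; the function name LarsLocalRate and the per-layer norm values are hypothetical and chosen only to illustrate how local_gw_ratio and weight_decay (taken from the solver.prototxt files above) shape the per-layer step. In the solver, the norms come from sqrt(Blob::sumsq_data()) and sqrt(Blob::sumsq_diff()).

    // lars_sketch.cpp -- illustrative only, not part of this patch.
    // Computes rate = gw_ratio * ||w|| / (||dw|| + weight_decay * ||w||),
    // the LARS multiplier applied on top of the global (warmup/poly) LR.
    #include <cstdio>

    float LarsLocalRate(float w_norm, float wgrad_norm,
                        float gw_ratio, float weight_decay) {
      float rate = 1.f;  // degenerate norms fall back to the unscaled LR
      if (w_norm > 0.f && wgrad_norm > 0.f) {
        rate = gw_ratio * w_norm / (wgrad_norm + weight_decay * w_norm);
      }
      return rate;
    }

    int main() {
      const float w_norm = 14.2f;      // hypothetical ||w||_2 of one layer
      const float wgrad_norm = 0.9f;   // hypothetical ||dw||_2 of that layer
      const float gw_ratio = 0.001f;   // local_gw_ratio from the solvers above
      const float weight_decay = 0.0005f;
      // Prints ~0.0157: the layer's trust ratio, independent of other layers,
      // so layers with small gradients relative to their weights still move.
      std::printf("local lr multiplier = %g\n",
                  LarsLocalRate(w_norm, wgrad_norm, gw_ratio, weight_decay));
      return 0;
    }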