From 5e65df0699f4fabb077a75dcb0d6c0893e5cdfa6 Mon Sep 17 00:00:00 2001 From: Piotr Dollar Date: Mon, 3 May 2021 15:02:26 -0700 Subject: [PATCH] Model scaling in "Fast and Accurate Model Scaling" (#137) Summary: See GETTING_STARTED.md for example usage. -paper reference: https://arxiv.org/abs/2103.06877 -regnet.py: added regnet_cfg_to_anynet_cfg() -scaler.py: implements model scaler -scale_net.py: entry point for model scaler -GETTING_STARTED.md: added example usage for scaler Pull Request resolved: https://github.com/facebookresearch/pycls/pull/137 Reviewed By: theschnitz, vaibhava0 Differential Revision: D28154580 Pulled By: pdollar fbshipit-source-id: 7d015cba588a167b0f51c3cfb9564f76e25604de --- docs/GETTING_STARTED.md | 13 ++++++ pycls/core/config.py | 5 +++ pycls/models/regnet.py | 40 ++++++++++++++---- pycls/models/scaler.py | 91 +++++++++++++++++++++++++++++++++++++++++ tools/scale_net.py | 41 +++++++++++++++++++ 5 files changed, 181 insertions(+), 9 deletions(-) create mode 100644 pycls/models/scaler.py create mode 100644 tools/scale_net.py diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md index 667194f..a97cf63 100644 --- a/docs/GETTING_STARTED.md +++ b/docs/GETTING_STARTED.md @@ -96,3 +96,16 @@ python tools/time_net.py PREC_TIME.WARMUP_ITER 5 \ PREC_TIME.NUM_ITER 50 ``` + +### MODEL SCALING + +Scale a RegNetY-4GF by 4x using fast compound scaling (see https://arxiv.org/abs/2103.06877): + +``` +python tools/scale_net.py \ + --cfg configs/dds_baselines/regnety/RegNetY-4.0GF_dds_8gpu.yaml \ + OUT_DIR ./ \ + CFG_DEST "RegNetY-4.0GF_dds_8gpu_scaled.yaml" \ + MODEL.SCALING_FACTOR 4.0 \ + MODEL.SCALING_TYPE "d1_w8_g8_r1" +``` diff --git a/pycls/core/config.py b/pycls/core/config.py index 4844324..48e0cdf 100644 --- a/pycls/core/config.py +++ b/pycls/core/config.py @@ -41,6 +41,10 @@ # Perform activation inplace if implemented _C.MODEL.ACTIVATION_INPLACE = True +# Model scaling parameters, see models/scaler.py (has no effect unless scaler is used) 
+_C.MODEL.SCALING_TYPE = "" +_C.MODEL.SCALING_FACTOR = 1.0 + # ---------------------------------- ResNet options ---------------------------------- # _C.RESNET = CfgNode() @@ -388,6 +392,7 @@ def dump_cfg(): cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST) with pathmgr.open(cfg_file, "w") as f: _C.dump(stream=f) + return cfg_file def load_cfg(cfg_file): diff --git a/pycls/models/regnet.py b/pycls/models/regnet.py index 36c5d30..8902432 100644 --- a/pycls/models/regnet.py +++ b/pycls/models/regnet.py @@ -31,20 +31,42 @@ def generate_regnet(w_a, w_0, w_m, d, q=8): return ws, ds, num_stages, total_stages, ws_all, ws_cont +def generate_regnet_full(): + """Generates per stage ws, ds, gs, bs, and ss from RegNet cfg.""" + w_a, w_0, w_m, d = cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH + ws, ds = generate_regnet(w_a, w_0, w_m, d)[0:2] + ss = [cfg.REGNET.STRIDE for _ in ws] + bs = [cfg.REGNET.BOT_MUL for _ in ws] + gs = [cfg.REGNET.GROUP_W for _ in ws] + ws, bs, gs = bk.adjust_block_compatibility(ws, bs, gs) + return ws, ds, ss, bs, gs + + +def regnet_cfg_to_anynet_cfg(): + """Convert RegNet cfg to AnyNet cfg format (note: alters global cfg).""" + assert cfg.MODEL.TYPE == "regnet" + ws, ds, ss, bs, gs = generate_regnet_full() + cfg.MODEL.TYPE = "anynet" + cfg.ANYNET.STEM_TYPE = cfg.REGNET.STEM_TYPE + cfg.ANYNET.STEM_W = cfg.REGNET.STEM_W + cfg.ANYNET.BLOCK_TYPE = cfg.REGNET.BLOCK_TYPE + cfg.ANYNET.DEPTHS = ds + cfg.ANYNET.WIDTHS = ws + cfg.ANYNET.STRIDES = ss + cfg.ANYNET.BOT_MULS = bs + cfg.ANYNET.GROUP_WS = gs + cfg.ANYNET.HEAD_W = cfg.REGNET.HEAD_W + cfg.ANYNET.SE_ON = cfg.REGNET.SE_ON + cfg.ANYNET.SE_R = cfg.REGNET.SE_R + + class RegNet(AnyNet): """RegNet model.""" @staticmethod def get_params(): - """Convert RegNet to AnyNet parameter format.""" - # Generates per stage ws, ds, gs, bs, and ss from RegNet parameters - w_a, w_0, w_m, d = cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH - ws, ds = generate_regnet(w_a, w_0, w_m, d)[0:2] - ss = 
[cfg.REGNET.STRIDE for _ in ws] - bs = [cfg.REGNET.BOT_MUL for _ in ws] - gs = [cfg.REGNET.GROUP_W for _ in ws] - ws, bs, gs = bk.adjust_block_compatibility(ws, bs, gs) - # Get AnyNet arguments defining the RegNet + """Get AnyNet parameters that correspond to the RegNet.""" + ws, ds, ss, bs, gs = generate_regnet_full() return { "stem_type": cfg.REGNET.STEM_TYPE, "stem_w": cfg.REGNET.STEM_W, diff --git a/pycls/models/scaler.py b/pycls/models/scaler.py new file mode 100644 index 0000000..112ba69 --- /dev/null +++ b/pycls/models/scaler.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Model scaler for scaling strategies in https://arxiv.org/abs/2103.06877.""" + +from math import isclose + +import pycls.models.regnet as regnet +from pycls.core.config import cfg +from pycls.models.blocks import adjust_block_compatibility + + +def scaling_factors(scale_type, scale_factor): + """ + Computes model scaling factors to allow for scaling along d, w, g, r. + + Compute scaling factors such that d * w * w * r * r == scale_factor. + Here d is depth, w is width, g is groups, and r is resolution. + Note that scaling along g is handled in a special manner (see paper or code). + + Examples of scale_type include "d", "dw", "d1_w2", and "d1_w2_g2_r0". + A scale_type of the form "dw" is equivalent to "d1_w1_g0_r0". The scalar value + after each scaling dimension gives the relative scaling along that dimension. + For example, "d1_w2" indicates scaling twice as much along width as along depth. + Finally, scale_factor indicates the absolute amount of scaling. + + The "fast compound scaling" strategy from the paper is specified via "d1_w8_g8_r1". 
+ """ + if all(s in "dwgr" for s in scale_type): + weights = {s: 1.0 if s in scale_type else 0.0 for s in "dwgr"} + else: + weights = {sw[0]: float(sw[1::]) for sw in scale_type.split("_")} + weights = {**{s: 0.0 for s in "dwgr"}, **weights} + assert all(s in "dwgr" for s in weights.keys()), scale_type + sum_weights = weights["d"] + weights["w"] + weights["r"] or weights["g"] / 2 or 1.0 + d = scale_factor ** (weights["d"] / sum_weights) + w = scale_factor ** (weights["w"] / sum_weights / 2.0) + g = scale_factor ** (weights["g"] / sum_weights / 2.0) + r = scale_factor ** (weights["r"] / sum_weights / 2.0) + s_actual = d * w * w * r * r + assert d == w == r == 1.0 or isclose(s_actual, scale_factor, rel_tol=0.01) + return d, w, g, r + + +def scale_model(): + """ + Scale model blocks by the specified type and amount (note: alters global cfg). + + The actual scaling is specified by MODEL.SCALING_TYPE and MODEL.SCALING_FACTOR. + + Note that the scaler must be employed on a standalone config outside of the main + training loop. This is because it alters the global config, which is typically + frozen during training. So one should use this function to generate a new config and + save it to a file, and then invoke training separately on the new config. 
+ """ + assert cfg.MODEL.TYPE in ["anynet", "effnet", "regnet"] + # Get scaling factors + scale_type, scale = cfg.MODEL.SCALING_TYPE, cfg.MODEL.SCALING_FACTOR + d_scale, w_scale, g_scale, r_scale = scaling_factors(scale_type, scale) + if cfg.MODEL.TYPE == "regnet": + # Convert a RegNet to an AnyNet prior to scaling + regnet.regnet_cfg_to_anynet_cfg() + if cfg.MODEL.TYPE == "anynet": + # Scale AnyNet + an = cfg.ANYNET + ds, ws, bs, gs = an.DEPTHS, an.WIDTHS, an.BOT_MULS, an.GROUP_WS + bs = bs if bs else [1] * len(ds) + gs = gs if gs else [1] * len(ds) + ds = [max(1, round(d * d_scale)) for d in ds] + ws = [max(1, round(w * w_scale / 8)) * 8 for w in ws] + gs = [max(1, round(g * g_scale)) for g in gs] + gs = [g if g <= 2 else 4 if g <= 5 else round(g / 8) * 8 for g in gs] + ws, bs, gs = adjust_block_compatibility(ws, bs, gs) + an.DEPTHS, an.WIDTHS, an.BOT_MULS, an.GROUP_WS = ds, ws, bs, gs + elif cfg.MODEL.TYPE == "effnet": + # Scale EfficientNet + en = cfg.EN + ds, ws, bs, sw, hw = en.DEPTHS, en.WIDTHS, en.EXP_RATIOS, en.STEM_W, en.HEAD_W + ds = [max(1, round(d * d_scale)) for d in ds] + ws = [max(1, round(w * w_scale / 8)) * 8 for w in ws] + sw = max(1, round(sw * w_scale / 8)) * 8 + hw = max(1, round(hw * w_scale / 8)) * 8 + ws, bs, _ = adjust_block_compatibility(ws, bs, [1] * len(ds)) + en.DEPTHS, en.WIDTHS, en.EXP_RATIOS, en.STEM_W, en.HEAD_W = ds, ws, bs, sw, hw + # Scale image resolution + cfg.TRAIN.IM_SIZE = round(cfg.TRAIN.IM_SIZE * r_scale / 4) * 4 + cfg.TEST.IM_SIZE = round(cfg.TEST.IM_SIZE * r_scale / 4) * 4 diff --git a/tools/scale_net.py b/tools/scale_net.py new file mode 100644 index 0000000..0661672 --- /dev/null +++ b/tools/scale_net.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Scale a model using scaling strategies from "Fast and Accurate Model Scaling". 
+ +For reference on scaling strategies, see: https://arxiv.org/abs/2103.06877. +For example usage, see GETTING_STARTED, MODEL SCALING section. +For implementation details, see pycls/models/scaler.py. + +This function takes a config as input, scales the model in the config, and saves the +config for the scaled model back to disk (to OUT_DIR/CFG_DEST). The typical params in +the config that need to be specified when scaling a model are MODEL.SCALING_FACTOR and +MODEL.SCALING_TYPE. For the SCALING_TYPE, "d1_w8_g8_r1" gives the fast compound scaling +and is the likely best default option. For further details, see pycls/models/scaler.py. +""" + +import pycls.core.builders as builders +import pycls.core.config as config +import pycls.core.net as net +import pycls.models.scaler as scaler + + +def main(): + config.load_cfg_fom_args("Scale a model.") + config.assert_and_infer_cfg() + cx_orig = net.complexity(builders.get_model()) + scaler.scale_model() + cx_scaled = net.complexity(builders.get_model()) + cfg_file = config.dump_cfg() + print("Scaled config dumped to:", cfg_file) + print("Original model complexity:", cx_orig) + print("Scaled model complexity:", cx_scaled) + + +if __name__ == "__main__": + main()