Skip to content

Commit

Permalink
ADD main
Browse files Browse the repository at this point in the history
  • Loading branch information
sminy67 committed Nov 5, 2022
1 parent 86261f6 commit 055ad08
Showing 1 changed file with 125 additions and 0 deletions.
125 changes: 125 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# coding=utf-8
from __future__ import absolute_import, division, print_function

import logging
import argparse
import os
import random
import numpy as np

import wandb
import torch
import torch.nn as nn

from utils.data_utils import get_loader

from models.mlp_mixer import MlpMixer, CONFIGS
from models import configs
from utils.infer import Inference
from utils.train import Trainer
from admm import AdmmTrainer

def set_seed(args):
    """Seed every RNG this script uses so runs are reproducible.

    Args:
        args: parsed argparse namespace; only ``args.seed`` (int) is read.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # The script trains on CUDA (see main()); on older torch versions
    # torch.manual_seed did not seed every CUDA device, so seed them
    # explicitly when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

def set_model(args):
    """Build an MlpMixer for the configured dataset and load warmup weights.

    Args:
        args: parsed namespace; reads ``model_type``, ``dataset``,
            ``img_size`` and ``name``; writes ``num_classes``.

    Returns:
        tuple: ``(model, args)`` — the weight-loaded model and the
        (mutated) args namespace. The model is left on CPU; the caller
        is responsible for moving it to ``args.device``.
    """
    config = CONFIGS[args.model_type]

    # argparse `choices` restricts dataset to these two values, but guard
    # anyway so a direct caller cannot leave num_classes unset.
    if args.dataset == "cifar_10":
        args.num_classes = 10
    elif args.dataset == "cifar_100":
        args.num_classes = 100
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    model = MlpMixer(config, args.img_size, num_classes=args.num_classes,
                     patch_size=16, zero_head=False)
    # map_location="cpu" keeps checkpoint loading working on CPU-only
    # hosts (main() falls back to CPU when CUDA is unavailable).
    ckpt_path = os.path.join("saved_models/warmup_models", args.name + '.pt')
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))

    return model, args

def main():
    """Parse CLI arguments, configure the environment, then dispatch to
    inference or the selected training loop (plain or ADMM)."""
    parser = argparse.ArgumentParser()

    # MODEL and DATASET
    parser.add_argument("--model-type", choices=["Mixer-B_16", "Mixer-L_16"],
                        default="Mixer-B_16", help="Which model to use")
    parser.add_argument("--dataset", choices=["cifar_10", "cifar_100"],
                        default="cifar_10", help="Which dataset to use")
    parser.add_argument("--train-batch-size", default=64, type=int,
                        help="Batch size for Training")
    # BUG FIX: type=int was missing, so the value arrived as a string
    # whenever the flag was passed on the command line.
    parser.add_argument("--test-batch-size", default=128, type=int,
                        help="Batch size for Testing")
    parser.add_argument("--seed", default=77, type=int,
                        help="Random seed for initialization")

    # CONFIGS of training
    parser.add_argument("--train-type", choices=["original", "tt_format", "pruning", "admm"],
                        default="original", help="Defining training type")
    parser.add_argument("--epochs", default=50, type=int,
                        help="Number of epochs for training")
    parser.add_argument("--use-adam", default=0, type=int,
                        help="Use ADAM for optimizer")
    parser.add_argument("--weight-decay", default=0, type=float,
                        help="Weight decay if we apply some")
    parser.add_argument("--lr-schedular", choices=["cosine", "exp"], default="cosine",
                        help="How to decay the learning rate.")
    parser.add_argument("--learning-rate", default=5e-2, type=float,
                        help="The initial learning rate for optimizer")
    parser.add_argument("--max-grad-norm", default=1.0, type=float,
                        help="Max gradient norm")

    # CONFIGS of ADMM training or TT training
    parser.add_argument("--target-layer", default=[0,1,2,3,4,5,6,7,8,9,10,11,12], nargs='+', type=int,
                        help="Indexes of layers that are applied with TT-format or Pruning")
    parser.add_argument("--warmup-training", default=0, type=int,
                        help="Begin with warmup training before ADMM training")
    parser.add_argument("--warmup-epochs", default=5, type=int,
                        help="Number of epochs for warmup training")
    parser.add_argument("--admm-epochs", default=30, type=int,
                        help="Number of epochs for ADMM training")
    # BUG FIX: help text claimed default 1e-2 while the default is 1e-1.
    parser.add_argument('--rho', type=float, default=1e-1,
                        help='cardinality weight (default: 1e-1)')
    parser.add_argument('--alpha', type=float, default=5e-4, metavar='L',
                        help='l2 norm weight (default: 5e-4)')
    parser.add_argument('--l2', default=False, action='store_true',
                        help='apply l2 regularization')
    # BUG FIX for the three options below: `type=list` splits a CLI string
    # into characters (e.g. "16" -> ['1', '6']); nargs='+' with type=int
    # parses space-separated integers and keeps the list defaults intact.
    parser.add_argument("--tt-ranks", default=[16, 16], nargs='+', type=int,
                        help="TT-ranks for TT-decomposition")
    parser.add_argument("--hidden-tt-shape", default=[8, 8, 12], nargs='+', type=int,
                        help="Factorized hidden dim shape for TT-format")
    parser.add_argument("--channels-tt-shape", default=[12, 16, 16], nargs='+', type=int,
                        help="Factorized channel dim shape for TT-format")

    # CONFIGS of testing
    parser.add_argument("--inference-only", default=0, type=int,
                        help="Test model mode")
    parser.add_argument("--test-type", choices=["original", "tt_format", "pruning", "admm"],
                        default="original", help="Defining training type")
    args = parser.parse_args()

    # NOTE(review): pins the process to GPU index 1 — confirm this is
    # intentional rather than a leftover from a specific machine.
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'

    args.name = args.model_type + '_' + args.dataset
    args.n_gpu = torch.cuda.device_count()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.img_size = 224

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)

    # Set seed
    set_seed(args)

    if args.inference_only:
        infer = Inference(args)
        # BUG FIX: an inference-only run previously fell through and also
        # constructed a trainer; stop here instead.
        return

    if args.train_type != "admm":
        train = Trainer(args)
    else:
        train = AdmmTrainer(args)

if __name__=="__main__":
    main()

0 comments on commit 055ad08

Please sign in to comment.