Skip to content

Commit

Permalink
Merge pull request FlagAI-Open#195 from shunxing1234/dev_bmtrain
Browse files Browse the repository at this point in the history
add vit bmtrain train demo
  • Loading branch information
ftgreat authored Jan 6, 2023
2 parents f0ee4a4 + 98df65f commit 48678fc
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 1 deletion.
80 changes: 80 additions & 0 deletions examples/vit_cifar100/train_bmtrain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR100
import ssl
# NOTE(review): this swaps the default HTTPS context factory for the
# unverified one, so TLS certificate verification is disabled for HTTPS
# requests that rely on the stdlib default context in this process.
# Presumably a workaround for certificate errors while downloading CIFAR-100
# -- confirm it is still required before reusing this script elsewhere.
ssl._create_default_https_context = ssl._create_unverified_context
from flagai.trainer import Trainer
from flagai.auto_model.auto_loader import AutoLoader

# Optimisation hyper-parameters handed to the Trainer below.
lr = 2e-3
n_epochs = 50
# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Distributed backend selected for this demo.
env_type = "bmtrain"

# The experiment name and checkpoint directory are tagged with the backend
# actually selected by ``env_type`` so that logs and checkpoints written by
# this script are not labelled as (or mixed up with) runs of another backend.
trainer = Trainer(
    env_type=env_type,
    experiment_name="vit-cifar100-bmtrain",
    batch_size=64,
    num_gpus=2,
    gradient_accumulation_steps=1,
    lr=lr,
    warm_up=0.001,
    weight_decay=1e-5,
    epochs=n_epochs,
    log_interval=100,
    eval_interval=1000,
    load_dir=None,
    pytorch_device=device,
    save_dir="checkpoints_vit_cifar100_bmtrain",
    save_interval=1000,
    num_checkpoints=1,
    hostfile="./hostfile",
    # NOTE(review): presumably only consulted by the DeepSpeed backend; kept
    # unchanged so the Trainer call stays identical -- confirm.
    deepspeed_config='./deepspeed.json',
    training_script=__file__,
)

def build_cifar():
    """Create the CIFAR-100 training and test datasets for this demo.

    Both splits live under ``./data/cifar100`` and are requested with
    ``download=True``. Every image goes through ``Resize(224)``,
    ``ToTensor`` and per-channel normalisation; the training split is
    additionally augmented with a padded random crop and AutoAugment.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    data_root = "./data/cifar100"
    # NOTE(review): these look like the commonly quoted CIFAR-10 channel
    # statistics rather than CIFAR-100 ones -- confirm this is intended.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    # Training pipeline: augmentation first, then tensor conversion and
    # normalisation.
    train_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.Resize(224),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    # Evaluation pipeline: deterministic, no augmentation.
    eval_pipeline = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    train_split = CIFAR100(
        root=data_root,
        train=True,
        download=True,
        transform=train_pipeline,
    )
    test_split = CIFAR100(
        root=data_root,
        train=False,
        download=True,
        transform=eval_pipeline,
    )
    return train_split, test_split

def collate_fn(batch):
    """Merge a list of samples into a single batch dictionary.

    Args:
        batch: Sequence of samples where index 0 is the image tensor and
            index 1 is the class label.

    Returns:
        A dict with ``"images"`` (the stacked image tensors, converted to
        half precision when ``trainer.fp16`` is truthy) and ``"labels"``
        (a long tensor built from the labels).
    """
    pixel_batch = torch.stack([sample[0] for sample in batch])
    # Match the input precision to the module-level trainer's fp16 setting.
    pixel_batch = pixel_batch.half() if trainer.fp16 else pixel_batch
    target_batch = torch.tensor([sample[1] for sample in batch]).long()
    return {"images": pixel_batch, "labels": target_batch}

def validate(logits, labels, meta=None):
    """Return the fraction of samples whose top-scoring class equals its label.

    Args:
        logits: Score tensor; the maximum is taken along dimension 1.
        labels: Tensor of target classes; its first dimension gives the
            number of samples.
        meta: Accepted for interface compatibility and not used here.

    Returns:
        Accuracy as a Python float (correct predictions / sample count).
    """
    top_class = logits.max(dim=1).indices
    n_hits = (top_class == labels).sum().item()
    return n_hits / labels.size(0)

if __name__ == '__main__':
    # Load the "vit-base-p16-224" classification model configured for the
    # 100 CIFAR-100 classes.
    auto_loader = AutoLoader(
        task_name="classification",
        model_name="vit-base-p16-224",
        num_classes=100,
    )
    vit_model = auto_loader.get_model()

    train_split, eval_split = build_cifar()

    # Train with accuracy reported through ``validate`` and batches assembled
    # by ``collate_fn``.
    trainer.train(
        vit_model,
        train_dataset=train_split,
        valid_dataset=eval_split,
        metric_methods=[["accuracy", validate]],
        collate_fn=collate_fn,
    )
2 changes: 1 addition & 1 deletion flagai/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ def train_step_bmtrain(self,
lm_loss = bmt.sum_loss(loss).item()
self.timers('backward').start()
try:
optim_manager.backward()
optim_manager.backward(loss)
except:
loss.backward()
self.timers('backward').stop()
Expand Down

0 comments on commit 48678fc

Please sign in to comment.