forked from FlagAI-Open/FlagAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request FlagAI-Open#58 from baai-open-internal/swinv1v2_checkpoint_activations
add swinv1v2
- Loading branch information
Showing
12 changed files
with
1,663 additions
and
16 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+149 KB
examples/swinv1/imagenet2012/val/n13044778/ILSVRC2012_val_00027938.JPEG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import os | ||
import torch | ||
import torchvision.datasets as datasets | ||
from flagai.auto_model.auto_loader import AutoLoader | ||
from torch.utils.data import DataLoader | ||
from torchvision import transforms | ||
|
||
# Script-level setup: dataset location, compute device and the SwinV1 model.
data_path = "./imagenet2012/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ask AutoLoader for the classification model by name.
loader = AutoLoader(
    task_name="classification",
    model_name='swinv1-base-patch4-window7-224',
    num_classes=1000,
)
model = loader.get_model()
model.eval()
model = model.to(device)
|
||
def data_loader(root, batch_size=64, workers=8):
    """Build the validation DataLoader for the ``val`` directory under *root*.

    Every image is resized to 256, center-cropped to 224, converted to a
    tensor and normalized with the mean/std constants below. Samples are
    served in dataset order (``shuffle=False``).

    Args:
        root: Directory that contains the ``val`` sub-directory.
        batch_size: Number of samples per batch.
        workers: Value forwarded as the loader's ``num_workers``.

    Returns:
        A ``DataLoader`` over a torchvision ``ImageFolder`` dataset.
    """
    preprocessing = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    dataset = datasets.ImageFolder(os.path.join(root, 'val'), preprocessing)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=False,
                      num_workers=workers)
|
||
# Evaluate the pretrained weights.
@torch.no_grad()
def test(model, data_loader):
    """Print the top-1 accuracy of *model* over every batch of *data_loader*.

    The model is called with the batch as the ``images`` keyword and must
    return a mapping that has a ``"logits"`` entry. Inputs and labels are
    moved to the module-level ``device`` before the forward pass.

    Args:
        model: Callable model; it is switched to eval mode first.
        data_loader: Iterable of ``(inputs, labels)`` batches exposing a
            ``dataset`` attribute whose length is the accuracy denominator.
    """
    model.eval()
    correct = 0.0

    for images, targets in data_loader:
        images = images.to(device)
        targets = targets.to(device)
        logits = model(images=images)["logits"]

        # Index of the largest logit along dim 1 is the predicted class.
        predictions = logits.max(1)[1]
        correct += predictions.eq(targets).sum().item()

    print(
        "test_top1_acc [{top1_acc}] \n".format(
            top1_acc=correct / len(data_loader.dataset),
        )
    )
|
||
if __name__ == '__main__':
    # Evaluate the model on the validation split found under data_path.
    test(model, data_loader(data_path))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import os | ||
import torch | ||
import torchvision.datasets as datasets | ||
from flagai.auto_model.auto_loader import AutoLoader | ||
from torchvision import transforms | ||
from flagai.trainer import Trainer | ||
|
||
# Script-level setup: compute device, dataset location, trainer and model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_path = "./imagenet2012/"

# Trainer configuration for the "pytorch" environment on one GPU.
trainer = Trainer(
    env_type="pytorch",
    experiment_name="swinv1_imagenet",
    num_gpus=1,
    epochs=10,
    batch_size=16,
    lr=5e-5,
    weight_decay=1e-3,
    warm_up=0.1,
    log_interval=10,
    eval_interval=1000,
    save_interval=1000,
)

# Ask AutoLoader for the classification model by name.
loader = AutoLoader(
    task_name="classification",
    model_name='swinv1-base-patch4-window7-224',
    num_classes=1000,
)
model = loader.get_model()
|
||
# Build the ImageNet train/validation datasets.
def build_dataset(root):
    """Return ``(train_dataset, val_dataset)`` for the folders under *root*.

    The training split (``root/train``) uses a random resized crop to 224
    and a random horizontal flip; the validation split (``root/val``) is
    resized to exactly 224x224. Both are converted to tensors and
    normalized with the same mean/std constants.

    Args:
        root: Directory containing the ``train`` and ``val`` sub-directories.

    Returns:
        A tuple of two torchvision ``ImageFolder`` datasets.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.ImageFolder(os.path.join(root, 'train'),
                                         train_transform)
    val_dataset = datasets.ImageFolder(os.path.join(root, 'val'),
                                       val_transform)
    return train_dataset, val_dataset
|
||
def collate_fn(batch):
    """Collate ``(image, label)`` samples into a dict of batched tensors.

    Images are stacked along a new leading dimension and cast to half
    precision when the module-level ``trainer`` has ``fp16`` set. Labels
    are gathered into a single int64 tensor.

    Args:
        batch: Sequence of samples; index 0 is the image tensor and index 1
            is the label.

    Returns:
        ``{"images": <stacked tensor>, "labels": <int64 tensor>}``.
    """
    image_list = []
    label_list = []
    for sample in batch:
        image_list.append(sample[0])
        label_list.append(sample[1])

    images = torch.stack(image_list)
    if trainer.fp16:
        images = images.half()
    return {"images": images, "labels": torch.tensor(label_list).long()}
|
||
def top1_acc(pred, labels, **kwargs):
    """Return the fraction of rows of *pred* whose argmax equals *labels*.

    Args:
        pred: Tensor of per-class scores; the class axis is dim 1.
        labels: Tensor of target class indices compared against the argmax.
        **kwargs: Accepted and ignored.

    Returns:
        A Python float: number of matches divided by the number of rows.
    """
    predicted_classes = pred.argmax(dim=1)
    num_correct = predicted_classes.eq(labels).sum().item()
    return num_correct / len(predicted_classes)
|
||
if __name__ == '__main__':
    # Build the datasets, then fine-tune the model with the trainer.
    print("building imagenet dataset......")
    train_dataset, val_dataset = build_dataset(root=data_path)
    print("training......")
    # Pass the Adam optimizer to the trainer explicitly; without the
    # ``optimizer=`` argument this object would be constructed and never used.
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
    trainer.train(model,
                  train_dataset=train_dataset,
                  valid_dataset=val_dataset,
                  collate_fn=collate_fn,
                  optimizer=optimizer,
                  metric_methods=[["top1_acc", top1_acc]],
                  find_unused_parameters=False)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+149 KB
examples/swinv2/imagenet2012/val/n13044778/ILSVRC2012_val_00027938.JPEG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import torch | ||
import os | ||
from torchvision import transforms | ||
from torch.utils.data import Dataset, DataLoader | ||
import torchvision.datasets as datasets | ||
from tqdm import tqdm | ||
from flagai.auto_model.auto_loader import AutoLoader | ||
|
||
# Script-level setup: dataset location, compute device and the SwinV2 model.
data_path = "./imagenet2012/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Supported swinv2 model_name values:
#   1. swinv2-base-patch4-window16-256
#   2. swinv2-small-patch4-window16-256
#   3. swinv2-base-patch4-window8-256
loader = AutoLoader(
    task_name="classification",
    model_name="swinv2-small-patch4-window16-256",
)
model = loader.get_model()

model.to(device)
model.eval()
|
||
# ImageNet validation loader.
def data_loader(root, batch_size=256, workers=1):
    """Build the validation DataLoader for the ``val`` directory under *root*.

    Every image is resized to exactly 256x256, converted to a tensor and
    normalized with the mean/std constants below. Samples are served in
    dataset order (``shuffle=False``).

    Args:
        root: Directory that contains the ``val`` sub-directory.
        batch_size: Number of samples per batch.
        workers: Value forwarded as the loader's ``num_workers``.

    Returns:
        A ``DataLoader`` over a torchvision ``ImageFolder`` dataset.
    """
    preprocessing = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    dataset = datasets.ImageFolder(os.path.join(root, 'val'), preprocessing)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=False,
                      num_workers=workers)
|
||
@torch.no_grad()
def test(model, data_loader):
    """Print top-1 and top-5 accuracy of *model* over *data_loader*.

    The model is called positionally with each batch and must return a
    mapping that has a ``"logits"`` entry. Inputs and labels are moved to
    the module-level ``device`` first; progress is shown with tqdm.

    Args:
        model: Callable model; it is switched to eval mode first.
        data_loader: Sized iterable of ``(inputs, labels)`` batches exposing
            a ``dataset`` attribute whose length is the accuracy denominator.
    """
    model.eval()
    top1_hits = 0.0
    top5_hits = 0.0

    for inputs, labels in tqdm(data_loader, total=len(data_loader)):
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)["logits"]

        # Top-1: index of the largest logit along dim 1.
        top1_preds = logits.max(1).indices
        top1_hits += top1_preds.eq(labels).sum().item()

        # Top-5: a sample counts when its label is among the five largest
        # logits; the (N, 1) label column broadcasts against the (N, 5) indices.
        top5_preds = logits.topk(5, 1, True).indices
        top5_hits += top5_preds.eq(labels.view(-1, 1)).sum().item()

    total = len(data_loader.dataset)
    print(
        "test_top1_acc [{top1_acc}], test_top5_acc [{top5_acc}] \n".format(
            top1_acc=top1_hits / total,
            top5_acc=top5_hits / total,
        )
    )
if __name__ == '__main__':
    # Evaluate the model on the validation split found under data_path.
    test(model, data_loader(data_path, batch_size=8, workers=8))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import os | ||
import torch | ||
import torchvision.transforms as transforms | ||
import torchvision.datasets as datasets | ||
from flagai.trainer import Trainer | ||
from flagai.auto_model import AutoLoader | ||
|
||
# Script-level setup: dataset location, compute device, trainer and model.
data_path = "./imagenet2012/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Use DDP for training on 4 GPUs.
trainer = Trainer(
    env_type="pytorchDDP",
    experiment_name="swinv2_imagenet_ddp",
    num_gpus=4,
    hostfile="./hostfile",
    training_script="training_swinv2.py",
    epochs=10,
    batch_size=32,
    lr=5e-5,
    weight_decay=1e-3,
    warm_up=0.1,
    log_interval=10,
    eval_interval=100,
    save_interval=100,
)

# Supported swinv2 model_name values:
#   1. swinv2-base-patch4-window16-256
#   2. swinv2-small-patch4-window16-256
#   3. swinv2-base-patch4-window8-256
loader = AutoLoader(
    task_name="classification",
    model_name="swinv2-base-patch4-window8-256",
    num_classes=1000,
)
model = loader.get_model()
|
||
# Build the ImageNet train/validation datasets.
def build_dataset(root):
    """Return ``(train_dataset, val_dataset)`` for the folders under *root*.

    The training split (``root/train``) uses a random resized crop to 256
    and a random horizontal flip; the validation split (``root/val``) is
    resized to exactly 256x256. Both are converted to tensors and
    normalized with the same mean/std constants.

    Args:
        root: Directory containing the ``train`` and ``val`` sub-directories.

    Returns:
        A tuple of two torchvision ``ImageFolder`` datasets.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    def pipeline(*steps):
        # Every split ends with tensor conversion followed by normalization.
        return transforms.Compose([*steps, transforms.ToTensor(), normalize])

    train_dataset = datasets.ImageFolder(
        os.path.join(root, 'train'),
        pipeline(transforms.RandomResizedCrop(256),
                 transforms.RandomHorizontalFlip()),
    )
    val_dataset = datasets.ImageFolder(
        os.path.join(root, 'val'),
        pipeline(transforms.Resize((256, 256))),
    )
    return train_dataset, val_dataset
|
||
def collate_fn(batch):
    """Turn a list of ``(image, label)`` samples into a batch dict.

    Args:
        batch: Sequence of samples; index 0 is the image tensor and index 1
            is the label.

    Returns:
        A dict with ``"images"`` (stacked tensor, cast to half precision
        when the module-level ``trainer.fp16`` is truthy) and ``"labels"``
        (int64 tensor).
    """
    stacked = torch.stack([sample[0] for sample in batch])
    images = stacked.half() if trainer.fp16 else stacked
    labels = torch.tensor([sample[1] for sample in batch]).long()
    return {"images": images, "labels": labels}
|
||
def top1_acc(pred, labels, **kwargs):
    """Compute top-1 accuracy for one batch.

    Args:
        pred: Tensor of per-class scores; the class axis is dim 1.
        labels: Tensor of target class indices.
        **kwargs: Accepted and ignored.

    Returns:
        Matches between the dim-1 argmax and *labels*, divided by the
        number of predictions, as a Python float.
    """
    winners = pred.argmax(dim=1)
    return winners.eq(labels).sum().item() / len(winners)
|
||
if __name__ == '__main__':
    # Build the datasets, then launch training through the trainer.
    print("building imagenet dataset......")
    train_dataset, val_dataset = build_dataset(root=data_path)

    print("training......")
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
    train_options = {
        "train_dataset": train_dataset,
        "valid_dataset": val_dataset,
        "collate_fn": collate_fn,
        "optimizer": optimizer,
        "metric_methods": [["top1_acc", top1_acc]],
        "find_unused_parameters": False,
    }
    trainer.train(model, **train_options)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
from .auto_loader import * |
Oops, something went wrong.