PyTorch implementation of the paper MLP-Mixer: An all-MLP Architecture for Vision

morthannn/mlp-mixer-pytorch

MLP Mixer

This repo implements the paper MLP-Mixer: An all-MLP Architecture for Vision in PyTorch.

Documentation

Install
cd mlp-mixer-pytorch
pip install -r requirements.txt  # install requirements
Training

Data structure:

mlp-mixer-pytorch
├── data/
│   ├── train/
│   │   ├── class_a/
│   │   │   ├── a_image_1.jpg
│   │   │   ├── a_image_2.jpg
│   │   │   └── a_image_3.jpg
│   │   ├── class_b/
│   │   │   ├── b_image_1.jpg
│   │   │   ├── b_image_2.jpg
│   │   │   └── b_image_3.jpg
│   │   └── class_c/
│   │       ├── c_image_1.jpg
│   │       ├── c_image_2.jpg
│   │       └── c_image_3.jpg
│   └── valid/
│       ├── class_a/
│       │   ├── a_image_1.jpg
│       │   ├── a_image_2.jpg
│       │   └── a_image_3.jpg
│       ├── class_b/
│       │   ├── b_image_1.jpg
│       │   ├── b_image_2.jpg
│       │   └── b_image_3.jpg
│       └── class_c/
│           ├── c_image_1.jpg
│           ├── c_image_2.jpg
│           └── c_image_3.jpg
└── train.py
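The layout above, with one sub-folder per class under `data/train` and `data/valid`, is the convention that `torchvision.datasets.ImageFolder` reads, where class indices follow the sorted folder names. Whether `train.py` uses `ImageFolder` is an assumption. The sketch below uses only the standard library to check a split folder and show the index each class would get under that convention; `scan_split` and `IMAGE_EXTENSIONS` are hypothetical names, not part of this repo.

```python
from pathlib import Path

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}  # assumption: extensions to count


def scan_split(split_dir):
    """Return (class_to_idx, counts) for a folder with one sub-folder per class."""
    split_dir = Path(split_dir)
    # Sorted folder names give a stable class order (the ImageFolder convention).
    class_names = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    counts = {
        name: sum(
            1
            for f in (split_dir / name).iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
        for name in class_names
    }
    return class_to_idx, counts


if __name__ == "__main__":
    for split in ("data/train", "data/valid"):
        if Path(split).is_dir():
            print(split, *scan_split(split))
```

Running it from the repo root before training is a cheap way to catch an empty class folder or a class that exists in `train/` but not in `valid/`.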
python train.py --epochs 300 --learning-rate 1e-3 --batch-size 128 --image-size 300 --patch-size 100 --num-mlp-blocks 8 --projection-dim 512 --token-mixing-dim 2048 --channel-mixing-dim 256 --num-workers 1 --device cuda:0
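As a quick check on the flags above: MLP-Mixer splits each image into non-overlapping patches, so the image size has to be divisible by the patch size. With `--image-size 300` and `--patch-size 100` this gives a 3 x 3 grid, i.e. 9 patches (tokens), each of which is projected to `--projection-dim` channels. The helper below is a hypothetical sketch of that arithmetic, assuming square patches and RGB input; it is not taken from `train.py`.

```python
def patch_grid(image_size, patch_size, channels=3):
    """Return (num_patches, values_per_patch) for non-overlapping square patches."""
    height, width = image_size
    if height % patch_size or width % patch_size:
        raise ValueError(
            f"image size {image_size} is not divisible by patch size {patch_size}"
        )
    num_patches = (height // patch_size) * (width // patch_size)
    # Each patch is flattened before the linear projection.
    values_per_patch = patch_size * patch_size * channels
    return num_patches, values_per_patch


num_patches, values_per_patch = patch_grid((300, 300), 100)
print(num_patches, values_per_patch)  # 9 30000
```

Each Mixer block then mixes information across the 9 patches (token mixing) and across the 512 channels (channel mixing); `--token-mixing-dim` and `--channel-mixing-dim` presumably set the hidden widths of those two MLPs.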
Inference
import torch
from PIL import Image
from torchvision import transforms
from model.mlp_mixer import MLPMixer

# Model
model = MLPMixer(
    num_classes=2,
    image_size=(300, 300),
    patch_size=100,
    num_mlp_blocks=8,
    projection_dim=512,
    token_mixing_dim=2048,
    channel_mixing_dim=256
)

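# torch.load does not expand wildcards: replace exp_* with the actual run folder
model.load_state_dict(torch.load('runs/exp_*/last.pt', map_location='cpu'))
model.eval()  # switch to evaluation mode before inference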

# Image
image_path = "name_image.jpg"
image = Image.open(image_path).convert("RGB")  # force 3 channels (handles grayscale/RGBA files)

# Transforms
transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor()
])

image = transform(image)

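# Inference
with torch.no_grad():  # gradients are not needed for prediction
    logits = model(image.unsqueeze(0))  # unsqueeze adds the batch dimension

# Results
results = logits.argmax(dim=1)  # index of the predicted class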
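
`results` holds the index of the highest logit for each image in the batch. To report a class name and a confidence, the logits can be passed through a softmax and the index looked up in a class list. The sketch below does this in plain Python on a list of logits (for example the first row of the model output converted with `.tolist()`). `top_prediction` is a hypothetical helper, and `class_names` is an assumption: it must match the class order used during training, which this README does not specify.

```python
import math


def top_prediction(logits, class_names):
    """Return (class_name, probability) for the largest logit."""
    if len(logits) != len(class_names):
        raise ValueError("need exactly one class name per logit")
    # Subtract the maximum before exponentiating for numerical stability.
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]


# Hypothetical values: two logits for a two-class model.
name, prob = top_prediction([0.3, 2.1], ["class_a", "class_b"])
print(name, round(prob, 3))
```

With tensors, `torch.softmax` applied along `dim=1` computes the same probabilities for the whole batch at once.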
