
(Meta-learning) A very simple PyTorch implementation of MAML


hanchan11/torch_maml


PyTorch-MAML

Part 1. Introduction

Deep learning generally needs a large amount of data. If you don't have enough, you can start from pre-trained weights. Most data can be fitted starting from pre-trained weights, but some tasks still fail to converge to a good minimum. So is there a single set of weights from which every task can reach its best result?

Yes: this is the idea behind Model-Agnostic Meta-Learning (MAML). The biggest difference between MAML and pre-trained weights is what they optimize. Pre-trained weights minimize the loss of the original task only, while MAML learns initial weights from which the loss of every task can be minimized with only a few steps of training.
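To make the "few steps of training" idea concrete, here is a minimal pure-Python sketch of second-order MAML on a toy problem chosen for illustration (it is not this repository's Omniglot model). Assumptions: each task is 1-D linear regression y = a * x with its own slope a drawn from [1, 3], the model is y_hat = w * x, and MAML meta-learns the initial weight w. The inner loop takes one gradient step on a task's support set; the outer loop updates w with the query loss measured after that step, differentiating through the inner step. Because the loss is quadratic, derivatives are written by hand. The task distribution, the learning rates `alpha` and `beta`, and all function names are hypothetical.

```python
import random


def loss_grad_hess(w, xs, a):
    """MSE of y_hat = w*x against y = a*x, with its first and second derivative in w."""
    m = sum(x * x for x in xs) / len(xs)  # mean(x^2); loss = mean((w*x - a*x)^2)
    return m * (w - a) ** 2, 2 * m * (w - a), 2 * m


def adapt(w, xs, a, alpha):
    """Inner loop: one gradient step on the task's support set."""
    _, g, _ = loss_grad_hess(w, xs, a)
    return w - alpha * g


def meta_loss_and_grad(w, support, query, a, alpha):
    """Query loss after adaptation, and its exact derivative w.r.t. the initial w."""
    w_task = adapt(w, support, a, alpha)
    loss_q, g_q, _ = loss_grad_hess(w_task, query, a)
    _, _, h_s = loss_grad_hess(w, support, a)
    # Chain rule through the inner step: d(w_task)/dw = 1 - alpha * h_s
    return loss_q, g_q * (1 - alpha * h_s)


def sample_task(rng, k=5):
    """Hypothetical task: slope a in [1, 3]; support and query x values in [-1, 1]."""
    a = rng.uniform(1.0, 3.0)
    support = [rng.uniform(-1, 1) for _ in range(k)]
    query = [rng.uniform(-1, 1) for _ in range(k)]
    return a, support, query


def train_maml(steps=200, tasks_per_batch=10, alpha=0.1, beta=0.5, seed=0):
    """Outer loop: average meta-gradients over a batch of tasks, then update w."""
    rng = random.Random(seed)
    w = 0.0  # the meta-learned initialization
    for _ in range(steps):
        grads = []
        for _ in range(tasks_per_batch):
            a, support, query = sample_task(rng)
            _, g = meta_loss_and_grad(w, support, query, a, alpha)
            grads.append(g)
        w -= beta * sum(grads) / len(grads)  # meta (outer) update
    return w


if __name__ == "__main__":
    print(f"meta-learned w: {train_maml():.3f}")
```

The factor `1 - alpha * h_s` is the second-order term; replacing it with 1 gives the first-order approximation (FOMAML). With slopes drawn uniformly from [1, 3], w should settle near 2, the middle of the range, from which a single inner step moves toward any particular task's slope.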

Part 2. Quick Start

  1. Clone the repository.
git clone https://github.com/Runist/torch_maml.git
  2. Install the dependencies.
cd torch_maml
pip install -r requirements.txt
  3. Download the Omniglot dataset.
mkdir data
cd data
wget https://github.com/Runist/MAML-keras/releases/download/v1.0/Omniglot.tar
tar -xvf Omniglot.tar
  4. Start training.
python train.py
epoch 1: 100%|████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.45s/it, loss=1.2326]
=> loss: 1.2917   acc: 0.4990   val_loss: 0.8875   val_acc: 0.7963
epoch 2: 100%|████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.32s/it, loss=0.9818]
=> loss: 1.0714   acc: 0.6688   val_loss: 0.8573   val_acc: 0.7713
epoch 3: 100%|████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.34s/it, loss=0.9472]
=> loss: 0.9896   acc: 0.6922   val_loss: 0.8000   val_acc: 0.7773
epoch 4: 100%|████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.39s/it, loss=0.7929]
=> loss: 0.8258   acc: 0.7812   val_loss: 0.8071   val_acc: 0.7676
epoch 5: 100%|████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.14s/it, loss=0.6662]
=> loss: 0.7754   acc: 0.7646   val_loss: 0.7144   val_acc: 0.7833

Part 3. Train your own dataset

  1. Set the parameters in args.py to match your dataset. More details are available in my blog.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--train_data_dir', type=str,
                    default="./data/Omniglot/images_background/",
                    help='The directory containing the training image data.')
parser.add_argument('--val_data_dir', type=str,
                    default="./data/Omniglot/images_evaluation/",
                    help='The directory containing the validation image data.')
parser.add_argument('--n_way', type=int, default=10,
                    help='The number of classes in every task.')
parser.add_argument('--k_shot', type=int, default=1,
                    help='The number of support-set images per class in every task.')
parser.add_argument('--q_query', type=int, default=1,
                    help='The number of query-set images per class in every task.')
  2. Start training.
python train.py --n_way=5 --k_shot=1 --q_query=1
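The `n_way`, `k_shot` and `q_query` options together define one task (episode). How this repository builds its episodes is not shown here, so the following is only a generic sketch under stated assumptions: the dataset is a hypothetical dict mapping each class name to a list of samples, and `k_shot` and `q_query` count images per class (the usual few-shot convention). `sample_episode` is a hypothetical name.

```python
import random


def sample_episode(dataset, n_way, k_shot, q_query, rng=random):
    """Build one N-way K-shot episode from {class_name: [samples]} (hypothetical format).

    Returns (support, query) lists of (sample, label) pairs. Labels are
    re-indexed 0..n_way-1 for this episode only, so the same class can get
    a different label in another episode.
    """
    classes = rng.sample(sorted(dataset), n_way)  # n_way distinct classes
    support, query = [], []
    for label, cls in enumerate(classes):
        # Draw k_shot + q_query distinct samples so support and query never overlap
        picks = rng.sample(dataset[cls], k_shot + q_query)
        support += [(x, label) for x in picks[:k_shot]]
        query += [(x, label) for x in picks[k_shot:]]
    return support, query


if __name__ == "__main__":
    # Toy stand-in for Omniglot: 20 classes with 20 samples each
    toy = {f"char{c:02d}": [f"char{c:02d}_{i}.png" for i in range(20)] for c in range(20)}
    support, query = sample_episode(toy, n_way=5, k_shot=1, q_query=1, rng=random.Random(0))
    print(len(support), len(query))  # 5 5
```

With `--n_way=5 --k_shot=1 --q_query=1`, each episode under these assumptions holds 5 support images and 5 query images.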

Part 4. Paper and other implementations
