Deep learning typically needs vast amounts of data. When you don't have that much, a common remedy is to start from pre-trained weights. Pre-trained weights fit most datasets well, but some tasks still fail to converge to a good minimum from them. So, is there a single set of weights from which every task can reach a good result?

Yes: this is the idea behind "Model-Agnostic Meta-Learning" (MAML). The biggest difference between MAML and ordinary pre-trained weights is that pre-training minimizes the loss of the original task only, while MAML learns an initialization from which the loss of any task can be minimized in just a few training steps.
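To make the idea concrete, here is a minimal sketch of the MAML training loop on a toy problem (this is not the repo's actual code; the tasks and learning rates are illustrative). Each "task" is fitting y = slope * x with a one-parameter linear model; MAML searches for an initial weight from which one gradient step adapts well to every task, rather than a weight that is best on average before adaptation.

```python
import numpy as np

rng = np.random.default_rng(0)
tasks = [1.0, 2.0, 3.0]          # hypothetical tasks: target slopes to fit
x = rng.uniform(-1, 1, 20)       # shared inputs, for simplicity
alpha, beta = 0.1, 0.05          # inner (adaptation) / outer (meta) learning rates

def loss_and_grad(w, slope):
    """Squared-error loss of y = w*x against y = slope*x, and its gradient in w."""
    err = w * x - slope * x
    return np.mean(err ** 2), 2 * np.mean(x * err)

w = 0.0                          # meta-initialization being learned
for _ in range(500):
    meta_grad = 0.0
    for slope in tasks:
        _, g = loss_and_grad(w, slope)         # inner-loop gradient on this task
        w_adapted = w - alpha * g              # one adaptation step
        _, g_post = loss_and_grad(w_adapted, slope)
        # Differentiate THROUGH the inner update (the second-order MAML term):
        dwadapted_dw = 1 - alpha * 2 * np.mean(x ** 2)
        meta_grad += g_post * dwadapted_dw
    w -= beta * meta_grad / len(tasks)         # outer (meta) update

print(round(w, 2))  # -> 2.0, the initialization that adapts fastest to all tasks
```

For this quadratic toy problem the meta-gradient can be written in closed form, so the second-order term is a single scalar; in the real repo, PyTorch's autograd computes it through the adapted network's forward pass.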
- Clone the repository.

```shell
git clone https://github.com/Runist/torch_maml.git
```
- Install the dependency packages.

```shell
cd torch_maml
pip install -r requirements.txt
```
- Download the Omniglot dataset.

```shell
mkdir data
cd data
wget https://github.com/Runist/MAML-keras/releases/download/v1.0/Omniglot.tar
tar -xvf Omniglot.tar
```
- Start training.

```shell
python train.py
```

```
epoch 1: 100%|████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00, 1.45s/it, loss=1.2326]
=> loss: 1.2917 acc: 0.4990 val_loss: 0.8875 val_acc: 0.7963
epoch 2: 100%|████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00, 1.32s/it, loss=0.9818]
=> loss: 1.0714 acc: 0.6688 val_loss: 0.8573 val_acc: 0.7713
epoch 3: 100%|████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00, 1.34s/it, loss=0.9472]
=> loss: 0.9896 acc: 0.6922 val_loss: 0.8000 val_acc: 0.7773
epoch 4: 100%|████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00, 1.39s/it, loss=0.7929]
=> loss: 0.8258 acc: 0.7812 val_loss: 0.8071 val_acc: 0.7676
epoch 5: 100%|████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00, 2.14s/it, loss=0.6662]
=> loss: 0.7754 acc: 0.7646 val_loss: 0.7144 val_acc: 0.7833
```
- Adjust the parameters in args.py as needed. You can find more details on my blog.
```python
parser.add_argument('--train_data_dir', type=str,
                    default="./data/Omniglot/images_background/",
                    help='The directory containing the train image data.')
parser.add_argument('--val_data_dir', type=str,
                    default="./data/Omniglot/images_evaluation/",
                    help='The directory containing the validation image data.')
parser.add_argument('--n_way', type=int, default=10,
                    help='The number of classes in every task.')
parser.add_argument('--k_shot', type=int, default=1,
                    help='The number of support-set images for every task.')
parser.add_argument('--q_query', type=int, default=1,
                    help='The number of query-set images for every task.')
```
- Or start training with your own parameters.

```shell
python train.py --n_way=5 --k_shot=1 --q_query=1
```