This repository contains implementations of both Prototypical Networks, proposed by Snell et al. in 2017, and Prototypical Networks with Random Weights, proposed by the owner of this repository in 2021.
The repository also has scripts to train these models for the task of few-shot image classification on Omniglot and mini-ImageNet.
It is written in Python 3 and tested on Linux.
Clone the repository or download the compressed source code. If you chose the latter, extract the source code to a directory of your choice.
In both cases, open the project directory in your terminal.
Next, install the requirements by running:
pip3 install -r requirements.txt
In case you can't install the requirements as a user, run the following instead:
sudo pip3 install -r requirements.txt
You also need to install the protonets package with:
pip3 install -e .
You may need to install it with sudo:
sudo pip3 install -e .
After installing the requirements and the package, you're ready to go.
You can train two models:
- Prototypical Networks;
- Prototypical Networks with Random Weights.
And there are two available datasets:
- Omniglot;
- mini-ImageNet.
First, you need to go to the scripts directory.
Once you're in this directory, you need to download the datasets.
The dataset_downloader.py script takes a -d/--dataset argument. If you try to execute it without passing the required argument, you should expect to see the following message:
usage: dataset_downloader.py [-h] -d {all,omniglot,mini_imagenet}
dataset_downloader.py: error: the following arguments are required: -d/--dataset
From the output above, we can see that there are three possible choices: all, omniglot, and mini_imagenet.
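The usage message above has the shape that Python's argparse module prints for a required option with a fixed set of choices. As a minimal sketch, assuming the script is built on argparse (this README does not show its internals) and using build_parser as a hypothetical helper name, such an interface could be declared like this:

```python
import argparse


def build_parser():
    # Hypothetical helper: the real script's internals are not shown in this README.
    parser = argparse.ArgumentParser(prog="dataset_downloader.py")
    parser.add_argument(
        "-d",
        "--dataset",
        required=True,  # omitting -d/--dataset makes argparse exit with an error
        choices=["all", "omniglot", "mini_imagenet"],
        help="which dataset to download",
    )
    return parser


# Parse an explicit argument list instead of sys.argv, for illustration only.
args = build_parser().parse_args(["-d", "omniglot"])
print(f"Selected dataset: {args.dataset}")
```

Passing a value outside the declared choices makes argparse print the usage line and exit, so invalid dataset names are rejected before any download starts.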
As an example, let's suppose we only want to download omniglot:
python3 dataset_downloader.py -d omniglot
After the download is complete, we can train a model on omniglot.
The training.py script takes two arguments: -m/--model and -d/--dataset. If you run it without passing the required arguments, you should expect to see the following message:
usage: training.py [-h] -m {vanilla,random_weights} -d {omniglot,mini_imagenet}
training.py: error: the following arguments are required: -m/--model, -d/--dataset
From the output above, we can see that each argument has two possible values: vanilla and random_weights for the model, and omniglot and mini_imagenet for the dataset.
Since we have downloaded omniglot, let's run:
python3 training.py -m vanilla -d omniglot
After the training is complete, we can retrain by running:
python3 retraining.py
And after retraining, we can evaluate our model with:
python3 evaluation.py
The results are stored in a directory called results.
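The four steps above (download, training, retraining, evaluation) can also be chained from a single Python script. The following is only a sketch: build_commands and run_pipeline are hypothetical helpers that are not part of this repository, and the scripts directory is assumed to be named scripts and to sit directly under the current working directory.

```python
import subprocess


def build_commands(model, dataset):
    # The commands mirror the ones listed in this README, in the same order.
    return [
        ["python3", "dataset_downloader.py", "-d", dataset],
        ["python3", "training.py", "-m", model, "-d", dataset],
        ["python3", "retraining.py"],
        ["python3", "evaluation.py"],
    ]


def run_pipeline(model, dataset, scripts_dir="scripts"):
    # check=True raises CalledProcessError as soon as one step fails,
    # so later steps never run on top of a broken earlier step.
    for command in build_commands(model, dataset):
        subprocess.run(command, cwd=scripts_dir, check=True)


# Show the commands without executing them.
for command in build_commands("vanilla", "omniglot"):
    print(" ".join(command))
```

Calling run_pipeline("vanilla", "omniglot") would execute the steps one after another. It is left out of the example because training can take a long time.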
Bear in mind that you have to rename or delete the results directory before training another model.
The retraining and the evaluation scripts work with the model obtained when you first execute the training script.
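Renaming the results directory before another training run can be scripted. The sketch below assumes the directory is called results and lives in the current working directory. archive_results is a hypothetical helper, and the timestamp suffix is just one possible naming scheme.

```python
from datetime import datetime
from pathlib import Path


def archive_results(results_dir="results"):
    """Rename an existing results directory so a new training run starts clean."""
    source = Path(results_dir)
    if not source.is_dir():
        return None  # nothing to archive
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    target = source.with_name(f"{source.name}_{stamp}")
    source.rename(target)
    return target


archived = archive_results()
if archived is None:
    print("No results directory found; nothing to rename.")
else:
    print(f"Previous results moved to {archived}")
```

Renaming rather than deleting keeps the earlier model available in case you want to compare runs later.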
You can find the few-shot setup and other parameters in the config directory.
The splits and the implementation follow the procedure described in Prototypical Networks for Few-shot Learning.
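For readers unfamiliar with that procedure, its core classification step can be sketched in a few lines: each class prototype is the mean of the embedded support examples of that class, and a query is assigned to the class with the nearest prototype. The sketch below is not this repository's code. It uses plain Python lists in place of the output of an embedding network, and the function names and toy values are assumptions made for illustration.

```python
import math


def mean_vector(vectors):
    # Element-wise mean of equally sized embedding vectors.
    count = len(vectors)
    return [sum(values) / count for values in zip(*vectors)]


def compute_prototypes(support):
    # support maps each class label to the embeddings of its support examples.
    return {label: mean_vector(vectors) for label, vectors in support.items()}


def classify(query, prototypes):
    # Pick the class whose prototype is nearest in Euclidean distance.
    return min(prototypes, key=lambda label: math.dist(query, prototypes[label]))


# Toy 2-way 2-shot episode with made-up 2-dimensional embeddings.
support = {
    "a": [[0.0, 0.0], [0.0, 2.0]],
    "b": [[4.0, 4.0], [6.0, 4.0]],
}
prototypes = compute_prototypes(support)
print(prototypes)
print(classify([1.0, 1.5], prototypes))
```

During training, the original paper turns the distances into class probabilities with a softmax over negative squared distances. Picking the nearest prototype, as done here, selects the same class as the most probable one.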
The results obtained with this implementation are comparable to those obtained with the original one.
You can check my execution logs and trained models here.
This project was based on:
- Cyprien Nielly's implementation of Prototypical Networks.
- The original implementation, which can be found on Jake Snell's GitHub.
Prototypical Networks (PNs) were originally introduced in Prototypical Networks for Few-shot Learning.
The idea of using weights to calculate the prototypes can be found in the paper Improved Prototypical Networks for Few-Shot Learning.
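To make the notion of a weighted prototype concrete, here is a generic sketch of a weighted average of support embeddings. It does not describe how this repository or the cited paper chooses the weights. weighted_prototype is a hypothetical helper, and the embeddings and weights are toy values.

```python
def weighted_prototype(vectors, weights):
    # Weighted average of support embeddings; weights need not sum to one.
    total = sum(weights)
    if total == 0:
        raise ValueError("weights must not sum to zero")
    return [
        sum(w * value for w, value in zip(weights, values)) / total
        for values in zip(*vectors)
    ]


embeddings = [[0.0, 0.0], [4.0, 2.0]]
print(weighted_prototype(embeddings, [1.0, 1.0]))  # uniform weights give the plain mean
print(weighted_prototype(embeddings, [1.0, 3.0]))  # the second example counts three times as much
```

With uniform weights the result is the ordinary mean used by standard Prototypical Networks, so the weighted form is a strict generalization of it.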