forked from samarth4149/PAN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit dded25a
Showing
20 changed files
with
1,962 additions
and
0 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env bash | ||
wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz | ||
tar -zxvf CUB_200_2011.tgz | ||
python write_CUB_filelist.py | ||
python get_attributes.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import csv | ||
import json | ||
import os | ||
|
||
import torch | ||
|
||
cwd = os.getcwd() | ||
img_name_file = os.path.join(cwd,'CUB_200_2011/images.txt') | ||
|
||
name_idx_map = {} | ||
with open(img_name_file, 'r') as f: | ||
rd = csv.reader(f, delimiter=' ') | ||
for row in rd: | ||
name_idx_map[row[1]] = int(row[0])-1 | ||
|
||
num_imgs = len(name_idx_map) | ||
num_attr = 312 | ||
attr_labels = torch.zeros(num_imgs, num_attr) | ||
certainty_mask = torch.zeros(num_imgs, num_attr) | ||
conf_scores = torch.zeros(num_imgs, num_attr) | ||
attr_file = os.path.join( | ||
cwd,'CUB_200_2011/attributes/image_attribute_labels.txt') | ||
|
||
with open(attr_file, 'r') as f: | ||
rd = csv.reader(f, delimiter=' ') | ||
for row in rd: | ||
img_idx = int(row[0])-1 | ||
attr_idx = int(row[1])-1 | ||
conf = int(row[3]) | ||
attr_labels[img_idx, attr_idx] = int(row[2]) | ||
conf_scores[img_idx, attr_idx] = conf | ||
if conf > 2.: | ||
certainty_mask[img_idx, attr_idx] = 1 | ||
|
||
print('Getting image attributes...') | ||
for split in ['base', 'novel', 'val']: | ||
print('Processing {}'.format(split)) | ||
with open('{}.json'.format(split), 'r') as f: | ||
meta = json.load(f) | ||
img_idxs = [] | ||
for img_name in meta['image_names']: | ||
curr_idx = name_idx_map['/'.join(img_name.split('/')[-2:])] | ||
img_idxs.append(curr_idx) | ||
img_idxs = torch.LongTensor(img_idxs) | ||
torch.save({ | ||
'attr_labels' : attr_labels[img_idxs], | ||
'certainty_mask' : certainty_mask[img_idxs], | ||
'conf_scores' : conf_scores[img_idxs] | ||
}, '{}_attr_test.pt'.format(split)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import os | ||
import random | ||
from os import listdir | ||
from os.path import isfile, isdir, join | ||
|
||
import numpy as np | ||
|
||
cwd = os.getcwd() | ||
data_path = join(cwd,'CUB_200_2011/images') | ||
savedir = './' | ||
dataset_list = ['base','val','novel'] | ||
|
||
|
||
folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))] | ||
folder_list.sort() | ||
label_dict = dict(zip(folder_list,range(0,len(folder_list)))) | ||
|
||
classfile_list_all = [] | ||
|
||
for i, folder in enumerate(folder_list): | ||
folder_path = join(data_path, folder) | ||
classfile_list_all.append( [ join(folder_path, cf) for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')]) | ||
random.shuffle(classfile_list_all[i]) | ||
|
||
print('Writing filelist...') | ||
for dataset in dataset_list: | ||
file_list = [] | ||
label_list = [] | ||
for i, classfile_list in enumerate(classfile_list_all): | ||
if 'base' in dataset: | ||
if (i%2 == 0): | ||
file_list = file_list + classfile_list | ||
label_list = label_list + np.repeat(i, len(classfile_list)).tolist() | ||
if 'val' in dataset: | ||
if (i%4 == 1): | ||
file_list = file_list + classfile_list | ||
label_list = label_list + np.repeat(i, len(classfile_list)).tolist() | ||
if 'novel' in dataset: | ||
if (i%4 == 3): | ||
file_list = file_list + classfile_list | ||
label_list = label_list + np.repeat(i, len(classfile_list)).tolist() | ||
|
||
fo = open(savedir + dataset + ".json", "w") | ||
fo.write('{"label_names": [') | ||
fo.writelines(['"%s",' % item for item in folder_list]) | ||
fo.seek(0, os.SEEK_END) | ||
fo.seek(fo.tell()-1, os.SEEK_SET) | ||
fo.write('],') | ||
|
||
fo.write('"image_names": [') | ||
fo.writelines(['"%s",' % item for item in file_list]) | ||
fo.seek(0, os.SEEK_END) | ||
fo.seek(fo.tell()-1, os.SEEK_SET) | ||
fo.write('],') | ||
|
||
fo.write('"image_labels": [') | ||
fo.writelines(['%d,' % item for item in label_list]) | ||
fo.seek(0, os.SEEK_END) | ||
fo.seek(fo.tell()-1, os.SEEK_SET) | ||
fo.write(']}') | ||
|
||
fo.close() | ||
print("%s -OK" %dataset) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
# Few shot learning using PAN | ||
## Requirements | ||
To install the requirements for this project. (Virtual environment encouraged, e.g. conda) | ||
|
||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Data | ||
Copy the directory `CUB_filelists` to the location where you would like to download | ||
CUB data (\~1.2 GB) | ||
|
||
``` | ||
cd CUB_filelists | ||
bash download_CUB.sh | ||
``` | ||
|
||
Then copy the files `base.json`, `val.json` and `novel.json` to a directory named `data` in the base folder of the project. | ||
|
||
## To run the siamese model | ||
The file main_siamese.py can be used to train the siamese model. | ||
``` | ||
usage: main_siamese.py [-h] [--eval_only] [--test_split] [--n_way N_WAY] | ||
[--n_support N_SUPPORT] [--n_query N_QUERY] | ||
[--num_val_ep NUM_VAL_EP] [--num_epoch NUM_EPOCH] | ||
[--data_aug] [-b BATCH_SIZE] [--margin MARGIN] | ||
[--resume PATH] [--lr LR] [--attr_pred] [--lam LAM] | ||
[--fix_val] [--save_dir SAVE_DIR] [--log_step LOG_STEP] | ||
Siamese Baseline for Few shot learning | ||
optional arguments: | ||
-h, --help show this help message and exit | ||
--eval_only Do only evaluation (presumably using a loaded model) | ||
--test_split Use the test split for validation | ||
--n_way N_WAY class num to classify for testing (validation) | ||
--n_support N_SUPPORT | ||
number of labeled data in each class | ||
--n_query N_QUERY number of query examples in each class | ||
--num_val_ep NUM_VAL_EP | ||
number of episodes for validation | ||
--num_epoch NUM_EPOCH | ||
Number of epochs for training | ||
--data_aug Train with data augmentation | ||
-b BATCH_SIZE, --batch_size BATCH_SIZE | ||
Minibatch size used | ||
--margin MARGIN Margin to use for triplet loss | ||
--resume PATH path to latest checkpoint (default: none) | ||
--lr LR, --learning-rate LR | ||
initial learning rate | ||
--attr_pred Use an auxiliary attribute predictor branch | ||
--lam LAM Weight of attribute loss in the final loss func | ||
--fix_val Each time validation is done, the same classes are | ||
chosen. Numpy random seed is fixed. | ||
--save_dir SAVE_DIR path to directory for saving results | ||
--log_step LOG_STEP log each <log_step> epoch | ||
``` | ||
Command to run training: | ||
``` | ||
python main_siamese.py --save_dir path/to/results --num_epoch 200 --data_aug | ||
``` | ||
Command to run inference: | ||
``` | ||
python main_siamese.py --resume path/to/saved_model --eval_only --test_split --num_val_ep 600 | ||
``` | ||
|
||
## Extract features from the siamese model | ||
Run | ||
``` | ||
for split in base val novel; | ||
do | ||
python extract_features.py --model path/to/model --split $split; | ||
done; | ||
``` | ||
where `path/to/model` is the path of the model file to use (saved from training the siamese model). | ||
|
||
## Extract attribute labels for CUB images | ||
Inside the directory `CUB_filelist`, run | ||
``` | ||
for split in base val novel; | ||
do | ||
python get_attributes.py --split $split; | ||
done; | ||
``` | ||
|
||
## To run the PAN model | ||
To run the pan model, | ||
``` | ||
usage: main_pan.py [-h] [--eval_only] [--test_split] [--n_way N_WAY] | ||
[--n_support N_SUPPORT] [--n_query N_QUERY] | ||
[--num_val_ep NUM_VAL_EP] [--num_epoch NUM_EPOCH] | ||
[--resume PATH] [--lr LR] [--fix_val] | ||
[-hi HIDDEN [HIDDEN ...]] [-deg DEGREE] [-do DROPOUT] | ||
[-sup_do SUPPORT_DROPOUT] [--hybrid] [--use_at_lab] | ||
[--label_func [{OR,AND,XNOR,AND_XOR}]] | ||
[--num_sup_at NUM_SUP_AT] [--nout NOUT] [--lam LAM] | ||
[--no_ge] [--save_dir SAVE_DIR] [--log_step LOG_STEP] | ||
PAN model for few shot learning | ||
optional arguments: | ||
-h, --help show this help message and exit | ||
--eval_only Do only evaluation (presumably using a loaded model) | ||
--test_split Use the test split for validation | ||
--n_way N_WAY class num to classify for testing (validation) | ||
--n_support N_SUPPORT | ||
number of labeled data in each class | ||
--n_query N_QUERY number of query examples in each class | ||
--num_val_ep NUM_VAL_EP | ||
number of episodes for validation | ||
--num_epoch NUM_EPOCH | ||
Number of epochs for training | ||
--resume PATH path to latest checkpoint (default: none) | ||
--lr LR, --learning-rate LR | ||
initial learning rate | ||
--fix_val Each time validation is done, the same classes are | ||
chosen. Numpy random seed is fixed. | ||
-hi HIDDEN [HIDDEN ...], --hidden HIDDEN [HIDDEN ...] | ||
Number of hidden units in the GCN layers. | ||
-deg DEGREE, --degree DEGREE | ||
Degree of the convolution (Number of supports) | ||
-do DROPOUT, --dropout DROPOUT | ||
Dropout fraction | ||
-sup_do SUPPORT_DROPOUT, --support_dropout SUPPORT_DROPOUT | ||
Use dropout on the support matrices, dropping all the | ||
connections from some nodes | ||
--hybrid Whether to a hybrid model. In this case nout is set to | ||
number of attributes + nout and use_at_lab is set to | ||
True | ||
--use_at_lab Whether to use attribute labels. In this case nout is | ||
set to number of attributes | ||
--label_func [{OR,AND,XNOR,AND_XOR}] | ||
Logical function to combine attribute labels | ||
--num_sup_at NUM_SUP_AT | ||
Use only the first num_sup_at attribute labels | ||
--nout NOUT Number of outputs of the MLP decoder, typically same | ||
as number of attributes | ||
--lam LAM Weight of attribute loss in the final loss func | ||
--no_ge Whether to use graph image encoder | ||
--save_dir SAVE_DIR path to directory for saving results | ||
--log_step LOG_STEP log each <log_step> epoch | ||
``` | ||
|
||
### Performance Comparison | ||
|
||
| Method | 5-way 5-shot accuracy | | ||
|------------------|-----------------------| | ||
| Baseline++ | 83.58 | | ||
| ProtoNet | 87.42 | | ||
| Trinet | 84.10 | | ||
| TEAM | 87.17 | | ||
| CGAE | 88.00 $\pm$ 1.13 | | ||
| PAN-Unsupervised | 92.69 $\pm$ 0.28 | | ||
| PAN-Supervised | 92.77 $\pm$ 0.30 | | ||
|
||
Commands for running above models (runnable on our code): | ||
- CGAE : `python main_pan.py --nout=1 --fix_val --save_dir=path/to/results` | ||
- PAN-Unsupervised : `python main_pan.py --no_ge --nout=50 --fix_val --save_dir=path/to/results` | ||
- PAN-Hybrid : `python main_pan.py --hybrid --no_ge --nout=10 --num_sup_at=10 --lam=1e-5 --fix_val --save_dir=path/to/results` | ||
- PAN-Supervised : `python main_pan.py --use_at_lab --no_ge --fix_val --lam=1e-5 --save_dir=path/to/results` | ||
|
||
Commands for running inference are similar to that for the Siamese Network (using `eval_only` and `test_split` options) | ||
|
||
Options for number of similarity conditions for different model variants : | ||
- For the unsupervised model, set `NOUT` to the number of similarity conditions. | ||
- For a supervised model, set `use_at_lab` to `True` and set `num_sup_at` to the number of attributes to use. | ||
- For a hybrid model, set `hybrid` to `True`. `NOUT` in this case would be the number of unsupervised similarity conditions. | ||
|
||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from utils.misc import get_degree_supports | ||
from utils.misc import process_supports | ||
import numpy as np | ||
import torch | ||
|
||
class ValDataset(torch.utils.data.Dataset): | ||
def __init__(self, args, features, img_labels, | ||
num_ep, n_way, n_support, n_query): | ||
# Note that n_support in the arguments is the number of support examples | ||
super(ValDataset, self).__init__() | ||
self.features = features | ||
self.img_labels = img_labels | ||
self.num_ep = num_ep | ||
self.n_way = n_way | ||
self.n_support = n_support | ||
self.n_query = n_query | ||
self.idxs = np.arange(len(self.img_labels)) | ||
self.classes = np.unique(self.img_labels) | ||
|
||
def __len__(self): | ||
return self.num_ep | ||
|
||
def __getitem__(self, idx): | ||
# first support examples and then query examples | ||
classes = self.classes[torch.randperm(len(self.classes))[:self.n_way]] | ||
examples = [] | ||
pos_idxs = np.array([(i, j) for i in range(self.n_support) | ||
for j in range(i + 1, self.n_support)]) | ||
all_pos_idxs = [] | ||
for i, cl in enumerate(classes): | ||
curr_idxs = self.idxs[self.img_labels == cl] | ||
curr_idxs = curr_idxs[torch.randperm( | ||
len(curr_idxs))[:(self.n_query + self.n_support)]] | ||
examples.append(self.features[curr_idxs]) | ||
# add edge pairs for current class | ||
all_pos_idxs.append(pos_idxs + i*(self.n_query + self.n_support)) | ||
|
||
|
||
# examples : (n_way * (n_support + n_query)) x feat_size | ||
examples = torch.cat(examples, axis=0) | ||
all_pos_idxs = np.concatenate(all_pos_idxs) | ||
|
||
query_idxs = [] | ||
for i in range(len(classes)): | ||
for j in range(self.n_support, self.n_query + self.n_support): | ||
for k in range(len(classes)): | ||
for l in range(self.n_support): | ||
query_idxs.append((i*(self.n_query+self.n_support)+j, | ||
k*(self.n_query+self.n_support)+l)) | ||
|
||
query_idxs = np.array(query_idxs) | ||
|
||
return examples, all_pos_idxs, query_idxs |
Oops, something went wrong.