MarkOulitin/vat_tf (forked from takerum/vat_tf)

Virtual adversarial training with TensorFlow


vat_tf

TensorFlow implementation for reproducing the semi-supervised learning results on the SVHN and CIFAR-10 datasets reported in the paper "Virtual Adversarial Training: a Regularization Method for Supervised and Semi-Supervised Learning" (http://arxiv.org/abs/1704.03976).

Requirements

tensorflow-gpu 1.1.0, scipy 0.19.0 (for ZCA whitening)
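The requirements list scipy for ZCA whitening. The sketch below illustrates ZCA whitening itself using NumPy's `numpy.linalg.eigh` in place of SciPy; the regularizer `eps` and the function names are assumptions, not values taken from the preparation scripts.

```python
import numpy as np

def zca_fit(X, eps=1e-2):
    """Fit ZCA whitening on rows of X (n_samples, n_features).
    Returns (mean, W) such that (X - mean) @ W is whitened."""
    mean = X.mean(axis=0)
    Xc = X - mean
    cov = Xc.T @ Xc / Xc.shape[0]
    eigvals, eigvecs = np.linalg.eigh(cov)  # covariance is symmetric
    # Rotate into the eigenbasis, rescale each axis, then rotate back;
    # rotating back is what distinguishes ZCA from plain PCA whitening.
    W = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T
    return mean, W

def zca_apply(X, mean, W):
    # Apply a previously fitted transform (e.g. fit on train, apply to test).
    return (X - mean) @ W
```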

Preparation of datasets for semi-supervised learning

On CIFAR-10

python cifar10.py --data_dir=./dataset/cifar10/

On SVHN

python svhn.py --data_dir=./dataset/svhn/
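This README does not spell out how the preparation scripts choose the labeled subset. A class-balanced split is a common choice in semi-supervised benchmarks; the sketch below shows one under that assumption. The helper name is hypothetical, `num_labeled` is left as a parameter rather than the value the scripts use, and here the remaining examples form the unlabeled pool (whether the scripts also keep the labeled examples in the unlabeled set is not stated).

```python
import numpy as np

def split_labeled_unlabeled(labels, num_labeled, num_classes, seed=0):
    """Pick num_labeled indices with an equal count per class; the rest are
    treated as unlabeled. Returns (labeled_idx, unlabeled_idx)."""
    if num_labeled % num_classes != 0:
        raise ValueError("num_labeled must be divisible by num_classes")
    rng = np.random.default_rng(seed)
    per_class = num_labeled // num_classes
    labeled = []
    for c in range(num_classes):
        idx = np.flatnonzero(labels == c)
        labeled.append(rng.choice(idx, size=per_class, replace=False))
    labeled_idx = np.sort(np.concatenate(labeled))
    mask = np.ones(len(labels), dtype=bool)
    mask[labeled_idx] = False
    return labeled_idx, np.flatnonzero(mask)
```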

Semi-supervised Learning without augmentation

On CIFAR-10

python train_semisup.py --dataset=cifar10 --data_dir=./dataset/cifar10/ --log_dir=./log/cifar10/ --num_epochs=500 --epoch_decay_start=460 --epsilon=10.0 --method=vat

On SVHN

python train_semisup.py --dataset=svhn --data_dir=./dataset/svhn/ --log_dir=./log/svhn/ --num_epochs=120 --epoch_decay_start=80 --epsilon=2.5 --top_bn --method=vat
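The `--num_epochs` and `--epoch_decay_start` flags suggest a learning rate that is held constant and then decayed from `epoch_decay_start` onward. Assuming a linear decay that reaches zero at `num_epochs` (an assumption; check `train_semisup.py` for the actual rule), the schedule could be sketched as follows, where `base_lr` is a hypothetical value.

```python
def learning_rate(epoch, base_lr, num_epochs, epoch_decay_start):
    """Constant base_lr, then linear decay to zero over the last
    (num_epochs - epoch_decay_start) epochs (assumed schedule)."""
    if epoch < epoch_decay_start:
        return base_lr
    remaining = num_epochs - epoch
    return base_lr * remaining / float(num_epochs - epoch_decay_start)
```

With the CIFAR-10 settings above (`--num_epochs=500 --epoch_decay_start=460`), this would keep the rate fixed for 460 epochs and anneal it over the final 40.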

Semi-supervised Learning with augmentation

On CIFAR-10

python train_semisup.py --dataset=cifar10 --data_dir=./dataset/cifar10/ --log_dir=./log/cifar10aug/ --num_epochs=500 --epoch_decay_start=460 --aug_flip=True --aug_trans=True --epsilon=8.0 --method=vat

On SVHN

python train_semisup.py --dataset=svhn --data_dir=./dataset/svhn/ --log_dir=./log/svhnaug/ --num_epochs=120 --epoch_decay_start=80 --epsilon=3.5 --aug_trans=True --top_bn --method=vat
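The `--aug_trans` and `--aug_flip` flags enable random translation and horizontal flipping. A minimal NumPy sketch of such augmentation on a single HWC image follows; the shift range `pad=2`, the zero padding, and the function name are assumptions, not details confirmed by this README. The SVHN commands enable only translation, presumably because mirroring can change how a digit reads.

```python
import numpy as np

def augment(image, rng, flip=True, trans=True, pad=2):
    """Randomly translate (pad, then random crop) and horizontally flip an
    HWC image. pad is the assumed maximum shift in pixels."""
    h, w, _ = image.shape
    if trans:
        padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="constant")
        top = rng.integers(0, 2 * pad + 1)   # crop offset in [0, 2 * pad]
        left = rng.integers(0, 2 * pad + 1)
        image = padded[top:top + h, left:left + w, :]
    if flip and rng.random() < 0.5:
        image = image[:, ::-1, :]            # mirror along the width axis
    return image
```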

Semi-supervised Learning with augmentation + entropy minimization

On CIFAR-10

python train_semisup.py --dataset=cifar10 --data_dir=./dataset/cifar10/ --log_dir=./log/cifar10aug/ --num_epochs=500 --epoch_decay_start=460 --aug_flip=True --aug_trans=True --epsilon=8.0 --method=vatent

On SVHN

python train_semisup.py --dataset=svhn --data_dir=./dataset/svhn/ --log_dir=./log/svhnaug/ --num_epochs=120 --epoch_decay_start=80 --epsilon=3.5 --aug_trans=True --top_bn --method=vatent
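`--method=vatent` adds entropy minimization on top of VAT. A minimal NumPy sketch of the conditional entropy term follows, assuming it is the batch mean of each example's prediction entropy and is simply added to the VAT objective; the function name is hypothetical.

```python
import numpy as np

def conditional_entropy(logits):
    """Mean entropy of softmax(logits) over a batch. Minimizing it pushes
    predictions on unlabeled data toward confident (low-entropy) ones."""
    z = logits - logits.max(axis=1, keepdims=True)           # stability shift
    log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))  # log-softmax
    p = np.exp(log_p)
    return -np.mean(np.sum(p * log_p, axis=1))
```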

Evaluation of the trained model

On CIFAR-10

python test.py --dataset=cifar10 --data_dir=./dataset/cifar10/ --log_dir=<path_to_log_dir>

On SVHN

python test.py --dataset=svhn --data_dir=./dataset/svhn/ --log_dir=<path_to_log_dir> --top_bn
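The evaluation step reports how well the trained model classifies the test set. As a sketch of that kind of metric, the function below computes the classification error rate in mini-batches; `predict_logits` is a hypothetical callable standing in for the trained network, and the batch size is arbitrary.

```python
import numpy as np

def error_rate(predict_logits, images, labels, batch_size=100):
    """Fraction of misclassified examples, evaluated in mini-batches."""
    wrong = 0
    for start in range(0, len(images), batch_size):
        logits = predict_logits(images[start:start + batch_size])
        preds = np.argmax(logits, axis=1)
        wrong += int(np.sum(preds != labels[start:start + batch_size]))
    return wrong / len(images)
```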
