PyTorch implementation of Zero-Shot Knowledge Distillation in Deep Networks
You will need a machine with a GPU and CUDA installed.
Then, prepare the runtime environment:
pip install -r requirements.txt
For the MNIST dataset:
python main.py --dataset=mnist --t_train=False --num_sample=12000 --batch_size=200
For the CIFAR-10 dataset:
python main.py --dataset=cifar10 --t_train=False --num_sample=24000 --batch_size=100
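The commands above pass boolean options as strings such as `--t_train=False`. We have not checked how `main.py` parses them, but if you write your own wrapper with Python's `argparse`, be aware that `type=bool` turns any non-empty string, including `'False'`, into `True`. A minimal sketch of a hypothetical parser for a subset of the flags listed in this README, using a hypothetical `str2bool` helper (the defaults are assumptions, not taken from `main.py`):

```python
import argparse


def str2bool(value: str) -> bool:
    """Convert strings such as 'True'/'False' into a real bool.

    argparse's type=bool would treat every non-empty string, including
    'False', as True, so an explicit converter is needed.
    """
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: flag names come from this README,
    # but the types and defaults are assumptions, not read from main.py.
    parser = argparse.ArgumentParser(description="ZSKD flags (sketch)")
    parser.add_argument("--dataset", choices=["mnist", "cifar10", "cifar100"], default="mnist")
    parser.add_argument("--t_train", type=str2bool, default=False)
    parser.add_argument("--num_sample", type=int, default=12000)
    parser.add_argument("--batch_size", type=int, default=200)
    parser.add_argument("--do_genimgs", type=str2bool, default=True)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--dataset=mnist", "--t_train=False", "--num_sample=12000", "--batch_size=200"]
    )
    print(args.t_train, type(args.t_train).__name__)  # False bool
```

With this converter, `--t_train=False` yields the boolean `False` instead of a truthy string.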
Arguments:
dataset
- available datasets: ['mnist', 'cifar10', 'cifar100']
t_train
- Train the teacher network?
  - if True, train the teacher network
  - elif False, load a trained teacher network
num_sample
- Number of Data Impressions (DIs) crafted per category
beta
- Beta scaling vectors
batch_size
- Batch size
lr
- Learning rate
iters
- Number of iterations
s_save_path
- Save path for the student network
do_genimgs
- Generate synthesized images with ZSKD?
  - if True, generate the images
  - elif False, you must already have the synthesized images generated by ZSKD
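`num_sample` and `beta` control how Data Impressions are crafted. Following our reading of the ZSKD paper, a class-similarity matrix is computed from the cosine similarity of the teacher's final-layer weight vectors, each row is scaled by `beta` and used as the concentration parameter of a Dirichlet distribution, and the sampled softmax vectors become the targets that random-noise inputs are optimised towards. The sketch below covers only the target-sampling step, in plain Python. The per-row min-max normalisation, the even split of `num_sample` across the beta values, the default betas `(1.0, 0.1)`, the `eps` offset, and all function names are assumptions for illustration, not this repository's code. The final step (optimising images against the trained teacher) is omitted.

```python
import math
import random


def class_similarity(weights):
    """Cosine similarity between per-class final-layer weight vectors,
    min-max normalised per row to [0, 1] (assumed normalisation)."""
    norms = [math.sqrt(sum(w * w for w in row)) for row in weights]  # rows must be non-zero
    k = len(weights)
    sim = [[sum(a * b for a, b in zip(weights[i], weights[j])) / (norms[i] * norms[j])
            for j in range(k)] for i in range(k)]
    normalised = []
    for row in sim:
        lo, hi = min(row), max(row)
        span = (hi - lo) or 1.0  # guard against identical rows
        normalised.append([(v - lo) / span for v in row])
    return normalised


def sample_dirichlet(alpha, rng):
    # A Dirichlet draw is a vector of independent Gamma(alpha_i, 1)
    # samples divided by their sum.
    draws = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(draws)
    return [d / total for d in draws]


def sample_targets(weights, num_sample, betas=(1.0, 0.1), eps=1e-8, seed=0):
    """Return {class index: list of sampled softmax target vectors}."""
    rng = random.Random(seed)
    sim = class_similarity(weights)
    per_beta = num_sample // len(betas)  # assumption: even split, remainder dropped
    targets = {}
    for k, row in enumerate(sim):
        samples = []
        for beta in betas:
            alpha = [beta * v + eps for v in row]  # eps keeps every alpha > 0
            samples.extend(sample_dirichlet(alpha, rng) for _ in range(per_beta))
        targets[k] = samples
    return targets


if __name__ == "__main__":
    toy_weights = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 0.0, 1.0]]
    targets = sample_targets(toy_weights, num_sample=4)
    print(len(targets), len(targets[0]))  # 3 4
```

A smaller `beta` concentrates the Dirichlet mass on fewer classes, so mixing several beta values yields both peaked and softer targets for the same class.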