This repository contains the code for comparing the performance of contrastive vs. regulariser (non-contrastive) self-supervised learning of structured world models (C-SWM; Thomas Kipf, Elise van der Pol, Max Welling).
It is based on the original C-SWM implementation; the only file that is modified is modules.py.
- Python 3.6 or 3.7
- PyTorch version 1.2
- OpenAI Gym version 0.12.0: pip install gym==0.12.0
- OpenAI atari-py version 0.1.4: pip install atari-py==0.1.4
- Scikit-image version 0.15.0: pip install scikit-image==0.15.0
- Matplotlib version 3.0.2: pip install matplotlib==3.0.2
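If you prefer, the pinned packages can be installed in one command (note that PyTorch 1.2 availability via pip may depend on your platform and Python version):

pip install torch==1.2.0 gym==0.12.0 atari-py==0.1.4 scikit-image==0.15.0 matplotlib==3.0.2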
2D Shapes:
python data_gen/env.py --env_id ShapesTrain-v0 --fname data/shapes_train.h5 --num_episodes 1000 --seed 1
python data_gen/env.py --env_id ShapesEval-v0 --fname data/shapes_eval.h5 --num_episodes 10000 --seed 2
Atari Pong:
python data_gen/env.py --env_id PongDeterministic-v4 --fname data/pong_train.h5 --num_episodes 1000 --atari --seed 1
python data_gen/env.py --env_id PongDeterministic-v4 --fname data/pong_eval.h5 --num_episodes 100 --atari --seed 2
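After generation, you can sanity-check a dataset by listing the contents of the HDF5 file. This is an optional check, not part of the original pipeline, and requires the h5py package; the internal group and dataset names depend on data_gen/env.py:

```python
import h5py

# Print the name of every group/dataset stored in the generated training file.
with h5py.File('data/shapes_train.h5', 'r') as f:
    f.visit(print)
```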
You need to pass the type of self-supervised loss function as an argument (--ssl-loss). Currently, the options are contrastive and vic.
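For reference, the two options correspond to two different ways of training the transition model in latent space. The sketch below only illustrates the idea and is not the exact code in modules.py; the function names, hinge margin, and loss coefficients are assumptions, and per-object latents are treated as flat vectors for simplicity.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(pred_next, enc_next, enc_neg, hinge=1.0):
    # Energy-based contrastive objective (C-SWM style): pull the predicted
    # next latent towards the encoding of the true next observation and push
    # it away from a randomly sampled negative encoding up to a hinge margin.
    pos = ((pred_next - enc_next) ** 2).sum(dim=-1).mean()
    neg = ((pred_next - enc_neg) ** 2).sum(dim=-1).mean()
    return pos + F.relu(hinge - neg)

def vic_loss(pred_next, enc_next, inv_w=25.0, var_w=25.0, cov_w=1.0):
    # Regulariser-based (non-contrastive) objective in the spirit of VICReg:
    # invariance + variance + covariance terms, no negative samples needed.
    n, d = pred_next.shape
    inv = F.mse_loss(pred_next, enc_next)                 # invariance term
    std = torch.sqrt(pred_next.var(dim=0) + 1e-4)
    var = F.relu(1.0 - std).mean()                        # keep per-dim std above 1
    centred = pred_next - pred_next.mean(dim=0)
    cov = (centred.t() @ centred) / (n - 1)
    off_diag = cov - torch.diag(torch.diag(cov))
    cov_term = (off_diag ** 2).sum() / d                  # decorrelate latent dimensions
    return inv_w * inv + var_w * var + cov_w * cov_term
```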
2D Shapes:
python train.py --dataset data/shapes_train.h5 --encoder small --name shapes --ssl-loss vic
python eval.py --dataset data/shapes_eval.h5 --save-folder checkpoints/shapes_vic --num-steps 1
Atari Pong:
python train.py --dataset data/pong_train.h5 --encoder medium --embedding-dim 4 --action-dim 6 --num-objects 3 --copy-action --epochs 200 --name pong --ssl-loss vic
python eval.py --dataset data/pong_eval.h5 --save-folder checkpoints/pong_vic --num-steps 1
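The evaluation reports Hits at rank 1 (H@1) and Mean Reciprocal Rank (MRR) of the true next state among all candidate states in latent space. A minimal, standalone illustration of how these metrics are computed from a list of ranks (not the actual eval.py code):

```python
import numpy as np

def hits_at_k(ranks, k=1):
    # Fraction of samples whose ground-truth next state is ranked in the top k.
    return float(np.mean(np.asarray(ranks) <= k))

def mean_reciprocal_rank(ranks):
    # Average of 1 / rank of the ground-truth next state.
    return float(np.mean(1.0 / np.asarray(ranks)))

# Example: ranks of the true next state for five evaluation samples.
ranks = [1, 1, 2, 1, 4]
print(hits_at_k(ranks, k=1))        # 0.6
print(mean_reciprocal_rank(ranks))  # 0.75
```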
2D Shapes:

Loss | H@1 | MRR |
---|---|---|
contrastive | 99 | 99 |
VICReg | 99 | 99 |
Atari Pong:

Loss | H@1 | MRR |
---|---|---|
contrastive | 39 | 57 |
VICReg | 46.8 | 62 |