# SFESS: Score Function Estimators for $k$-Subset Sampling
Official PyTorch implementation of the ICLR 2025 paper *SFESS: Score Function Estimators for $k$-Subset Sampling* by Klas Wijk, Ricardo Vinuesa & Hossein Azizpour.

[OpenReview](https://openreview.net/forum?id=q87GUkdQBm)
Implementations of gradient estimators for subset distributions:

- Gumbel-softmax top-$k$ (GS) (Xie & Ermon, 2019) [arXiv]
- Straight-through Gumbel-softmax top-$k$ (STGS) (Xie & Ermon, 2019) [arXiv]
- Implicit maximum likelihood estimation (I-MLE) (Niepert et al., 2021) [arXiv]
- SIMPLE (Ahmed et al., 2023) [arXiv]
- Score function estimators for $k$-subset sampling (SFESS) [this paper]
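As a rough illustration of the score-function approach (a minimal sketch, not the repository's implementation; `sample_k_subset` is a hypothetical helper), the snippet below draws a $k$-subset with the Gumbel top-$k$ trick and forms a REINFORCE surrogate loss from the Plackett-Luce log-probability of the *ordered* draw. Because the downstream loss depends only on the set, this still yields an unbiased gradient sample of $\mathbb{E}[f(S)]$:

```python
import torch

def sample_k_subset(logits, k):
    """Sample a k-subset via Gumbel top-k and return its binary mask
    together with the Plackett-Luce log-probability of the ordered draw."""
    gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
    order = torch.topk(logits + gumbel, k).indices  # items in sampled order
    mask = torch.zeros_like(logits)
    mask[order] = 1.0
    log_p = 0.0
    selected = torch.zeros_like(logits, dtype=torch.bool)
    for i in order:  # sequential (without-replacement) log-probability
        masked = logits.masked_fill(selected, float("-inf"))
        log_p = log_p + masked[i] - torch.logsumexp(masked, dim=0)
        selected = selected.clone()  # clone: `selected` is saved for backward
        selected[i] = True
    return mask, log_p

# Score-function (REINFORCE) surrogate: the gradient of
# f(S).detach() * log p(S) is an unbiased sample of d E[f(S)] / d logits.
logits = torch.zeros(5, requires_grad=True)
mask, log_p = sample_k_subset(logits, k=2)
f = mask[:3].sum()  # toy objective: reward selecting items 0-2
(f.detach() * log_p).backward()
```

In practice this estimator is typically combined with a baseline (control variate) to reduce variance, which is the setting the paper studies.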
Multiple experiments:

- Feature selection
- Learning to explain (L2X)
- Subset VAE
- Stochastic $k$-nearest neighbors
Requirements:

- numpy
- matplotlib
- seaborn
- pytorch
- torchvision
- lightning
- torchmetrics
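A typical setup with pip (an assumption; the repository may pin exact versions). Note that PyTorch is published on PyPI as `torch` and PyTorch Lightning as `lightning`:

```shell
pip install numpy matplotlib seaborn torch torchvision lightning torchmetrics
```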
To see the list of parameters for an experiment, run:

    python main.py [task] --help

where `[task]` is one of `{l2x,vae,knn}`.
The toy experiment is found in `/notebooks`.
This implementation extends code from:

- https://github.com/UCLA-StarAI/SIMPLE (toy experiment)
- https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/DL2/sampling/subsets.html (stochastic $k$-nearest neighbors)
- https://github.com/chendiqian/PR-MPNN and correspondence with Andrei-Marian Manolache, Kareem Ahmed, and Mathias Niepert (samplers)