Infomin Representation Learning


This repository provides a PyTorch implementation of the paper "Scalable Infomin Learning", NeurIPS 2022.

Introduction

We consider learning representations with the following objective: $$\min L(f(X), Y) + \beta \cdot I(f(X); T)$$ where $L$ is some loss (e.g. a classification loss) and $I$ is the mutual information. This objective is ubiquitous in fairness, disentangled representation learning, domain adaptation, invariance, etc.
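
As a concrete illustration, the objective can be written as an ordinary PyTorch loss. This is only a minimal sketch: the encoder, predictor head, the surrogate `mi_estimate`, and the value of `beta` below are illustrative placeholders, not the repository's actual API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative encoder f and predictor head; dimensions are placeholders.
f = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 64))
head = nn.Linear(64, 10)
beta = 1.0

def mi_estimate(z, t):
    # Stand-in for any estimator of I(f(X); T) (see the /mi folder):
    # here, the mean squared cross-correlation between features of z and t.
    zc = (z - z.mean(0)) / (z.std(0) + 1e-8)
    tc = (t - t.mean(0)) / (t.std(0) + 1e-8)
    return (zc.t() @ tc / z.shape[0]).pow(2).mean()

def infomin_loss(x, y, t):
    z = f(x)                                      # representation f(X)
    task_loss = F.cross_entropy(head(z), y)       # L(f(X), Y)
    return task_loss + beta * mi_estimate(z, t)   # + beta * I(f(X); T)
```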

We show that to minimise $I(f(X); T)$ above, we need not estimate it directly, which can be challenging. Rather, we can simply consider a random 'slice' of $I(f(X); T)$ during mini-batch learning, which is much easier to estimate.
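
Below is a minimal sketch of the slicing idea, assuming one random 1-D projection of $f(X)$ and of $T$ per mini-batch and a simple Pearson-correlation dependence measure between the two slices. It is illustrative only and not the exact estimator from the paper; see the /mi folder for the estimators actually used.

```python
import torch

def sliced_penalty(z, t):
    """Dependence between one random 1-D slice of z and one random 1-D slice of t."""
    u = torch.randn(z.shape[1], 1, device=z.device)   # random direction for f(X)
    v = torch.randn(t.shape[1], 1, device=t.device)   # random direction for T
    zs, ts = z @ u, t @ v                             # 1-D slices, shape (batch, 1)
    zs, ts = zs - zs.mean(), ts - ts.mean()           # centre the slices
    corr = (zs * ts).sum() / (zs.norm() * ts.norm() + 1e-8)
    return corr.pow(2)                                # penalise squared correlation

# Usage in a training step: loss = task_loss + beta * sliced_penalty(f(x), t)
```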

See also the accompanying materials: Poster, Slides, Demo. The demo is a minimalist Jupyter notebook for trying out our method.

Prerequisite

1. Libraries

  • Python 3.5+
  • PyTorch 1.12.1
  • torchvision 0.13.1
  • NumPy, SciPy, Matplotlib

We strongly recommend using conda to manage/update library dependencies:

conda install pytorch torchvision matplotlib

2. Data

Please run the following script to download the PIE dataset (contributed by https://github.com/bluer555/CR-GAN):

bash scripts/download_pie.sh

For the fairness experiments, the data is in the /data folder.

MI estimators/independence tests

at /mi

  • Pearson Correlation
  • Distance Correlation
  • Neural Total Correlation
  • Neural Renyi Correlation
  • CLUB
  • Sliced mutual information

These estimators are used to quantify $I(f(X); T)$.
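
As an example of what such a dependence measure looks like, below is a generic (biased-sample) distance correlation between two mini-batches. It is a self-contained sketch for illustration, not the implementation in /mi.

```python
import torch

def distance_correlation(x, y):
    """Biased-sample distance correlation between batches x and y of shape (batch, dim)."""
    def centred(a):
        d = torch.cdist(a, a)  # pairwise Euclidean distances within the batch
        return d - d.mean(0, keepdim=True) - d.mean(1, keepdim=True) + d.mean()
    A, B = centred(x), centred(y)
    dcov2 = (A * B).mean()                      # squared distance covariance
    dvarx, dvary = (A * A).mean(), (B * B).mean()
    dcor2 = dcov2 / (dvarx.sqrt() * dvary.sqrt() + 1e-12)
    return dcor2.clamp_min(0).sqrt()

# Example: values near 0 indicate weak dependence between f(X) and T.
z = torch.randn(256, 64)   # a batch of representations f(X)
t = torch.randn(256, 8)    # a batch of attributes T
print(distance_correlation(z, t))
```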

Applications

at /tasks

  • Fairness
  • Disentangled representation learning
  • Domain adaptation

Different tasks are isolated from each other.

Citation

If you find our paper / repository helpful, please consider citing:

@article{chen2023scalable,
  title={Scalable Infomin Learning},
  author={Chen, Yanzhi and Sun, Weihao and Li, Yingzhen and Weller, Adrian},
  journal={arXiv preprint arXiv:2302.10701},
  year={2023}
}

Results