Skip to content

Commit

Permalink
[Example] add latent dirichlet allocation (dmlc#2883)
Browse files Browse the repository at this point in the history
* add lda model

* tweak latent dirichlet allocation

* Update README.md

* Update README.md

* update example index

* update header

* minor tweak

* add example test

* update doc

* Update README.md

* Update README.md

* add partial_fit for free

* Update examples/pytorch/lda/lda_model.py

Co-authored-by: Quan (Andy) Gan <[email protected]>

* Update examples/pytorch/lda/example_20newsgroups.py

Co-authored-by: Quan (Andy) Gan <[email protected]>

* Update lda_model.py

* bugfix torch Gamma uses rate parameter

Co-authored-by: Yifei Ma <[email protected]>
Co-authored-by: Quan (Andy) Gan <[email protected]>
  • Loading branch information
3 people authored May 17, 2021
1 parent 657c220 commit c018436
Show file tree
Hide file tree
Showing 5 changed files with 449 additions and 0 deletions.
7 changes: 7 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The folder contains example implementations of selected research papers related

| Paper | node classification | link prediction / classification | graph property prediction | sampling | OGB |
| ------------------------------------------------------------ | ------------------- | -------------------------------- | ------------------------- | ------------------ | ------------------ |
| [Latent Dirichlet Allocation](#lda) | :heavy_check_mark: | :heavy_check_mark: | | | |
| [Network Embedding with Completely-imbalanced Labels](#rect) | :heavy_check_mark: | | | | |
| [Boost then Convolve: Gradient Boosting Meets Graph Neural Networks](#bgnn) | :heavy_check_mark: | | | | |
| [Contrastive Multi-View Representation Learning on Graphs](#mvgrl) | :heavy_check_mark: | | :heavy_check_mark: | | |
Expand Down Expand Up @@ -410,6 +411,12 @@ The folder contains example implementations of selected research papers related
- Example code: [PyTorch](https://github.com/awslabs/dgl-ke/tree/master/examples), [PyTorch for custom data](https://aws-dglke.readthedocs.io/en/latest/commands.html)
- Tags: knowledge graph embedding

## 2010

- <a name="lda"></a> Hoffman et al. Online Learning for Latent Dirichlet Allocation. [Paper link](https://papers.nips.cc/paper/2010/file/71f6278d140af599e06ad9bf1ba03cb0-Paper.pdf).
- Example code: [PyTorch](../examples/pytorch/lda)
- Tags: sklearn, decomposition, latent Dirichlet allocation

## 2009

- <a name="astar"></a> Riesen et al. Speeding Up Graph Edit Distance Computation with a Bipartite Heuristic. [Paper link](https://core.ac.uk/download/pdf/33054885.pdf).
Expand Down
76 changes: 76 additions & 0 deletions examples/pytorch/lda/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
Latent Dirichlet Allocation
===
LDA is a classical algorithm for probabilistic graphical models. It assumes
hierarchical Bayes models with discrete variables on sparse doc/word graphs.
This example shows how it can be done on DGL,
where the corpus is represented as a bipartite multi-graph G.
There is no back-propagation, because gradient descent is typically considered
inefficient on probability simplex.
On the provided small-scale example on 20 news groups dataset, our DGL-LDA model runs
50% faster on GPU than sklearn model without joblib parallel.

Key equations
---

* The corpus is generated by hierarchical Bayes: document(d) -> latent topic(z) -> word(w)
* All positions in the same document have shared topic distribution θ_d~Dir(α)
* All positions of the same topic have shared word distribution β_z~Dir(η)
* The words in the same document / topic are correlated.

**MAP**

A simplified MAP model is just a non-conjugate model with an inner summation to integrate out the latent topic variable:
<img src="https://latex.codecogs.com/gif.latex?p(G)=\prod_{(d,w)}\left(\sum_z\theta_{dz}\beta_{zw}\right)" title="map" />

The main complications are that θ_d / β_z are shared in the same document / topic and the variables reside in a probability simplex.
One way to work around it is via expectation maximization

<img src="https://latex.codecogs.com/gif.latex?\log&space;p(G)&space;=\sum_{(d,w)}\log\left(\sum_z\theta_{dz}\beta_{zw}\right)&space;\geq\sum_{(d,w)}\mathbb{E}_q\log\left(\frac{\theta_{dz}\beta_{zw}}{q(z;\phi_{dw})}\right)" title="map-em" />

* An explicit posterior is ϕ_dwz ∝ θ_dz * β_zw
* E-step: find summary statistics with fractional membership
* M-step: set θ_d, β_z proportional to the summary statistics
* With an explicit posterior, the bound is tight.

**Variational Bayes**

A Bayesian model adds Dirichlet priors to θ_d & β_z. This causes the posterior to be implicit and the bound to be loose. We will still use an independence assumption and cycle through the variational parameters similarly to coordinate ascent.

* The evidence lower-bound is
<img src="https://latex.codecogs.com/gif.latex?\log&space;p(G)\geq&space;\mathbb{E}_q\left[\sum_{(d,w)}\log\left(&space;\frac{\theta_{dz}\beta_{zw}}{q(z;\phi_{dw})}&space;\right)&space;&plus;\sum_{d}&space;\log\left(&space;\frac{p(\theta_d;\alpha)}{q(\theta_d;\gamma_d)}&space;\right)&space;&plus;\sum_{z}&space;\log\left(&space;\frac{p(\beta_z;\eta)}{q(\beta_z;\lambda_z)}&space;\right)\right]" title="elbo" />

* ELBO objective function factors as
<img src="https://latex.codecogs.com/gif.latex?\sum_{(d,w)}&space;\phi_{dw}^{\top}\left(&space;\mathbb{E}_{\gamma_d}[\log\theta_d]&space;&plus;\mathbb{E}_{\lambda}[\log\beta_{:w}]&space;-\log\phi_{dw}&space;\right)&space;\\&space;&plus;&space;\sum_d&space;(\alpha-\gamma_d)^\top\mathbb{E}_{\gamma_d}[\log&space;\theta_d]-(\log&space;B(\alpha)-\log&space;B(\gamma_d))&space;\\&space;&plus;&space;\sum_z&space;(\eta-\lambda_z)^\top\mathbb{E}_{\lambda_z}[\log&space;\beta_z]-(\log&space;B(\eta)-\log&space;B(\lambda_z))" title="factors" />

* Similarly, optimization alternates between ϕ, γ, λ. Since θ, β are random, we use an explicit solution for E[log X] under Dirichlet distribution via digamma function.

DGL usage
---
The corpus is represented as a bipartite multi-graph G.
We use DGL to propagate information through the edges and aggregate the distributions at doc/word nodes.
For scalability, the phi variables are transient and updated during message passing.
The gamma / lambda variables are updated after the nodes receive all edge messages.
Following the conventions in [1], the gamma update is called E-step and the lambda update is called M-step, because the beta variable has smaller variance.
The lambda variable is further recorded by the trainer and we may further approximate its MAP estimate by using a large step size for word nodes.
A separate function is used to produce perplexity, which is based on the ELBO objective function divided by the total numbers of word/doc occurrences.

Example
---
`%run example_20newsgroups.py`
* Approximately matches scikit-learn training perplexity after 10 rounds of training.
* Exactly matches scikit-learn training perplexity if word_z is set to lda.components_.T
* To compute testing perplexity, we need to fix the word beta variables via MAP estimate. This step is not taken by sklearn and its beta part seems to contain another bug by dividing the training loss by the testing word counts. Nonetheless, I recommend setting `step_size["word"]` to a larger value to approximate the corresponding MAP estimate.
* The DGL-LDA model runs 50% faster on GPU devices compared with sklearn without joblib parallel.

Advanced configurations
---
* Set `step_size["word"]` to a large value obtain a MAP estimate for beta.
* Set `0<word_rho<1` for online learning.

References
---

1. Matthew Hoffman, Francis Bach, David Blei. Online Learning for Latent
Dirichlet Allocation. Advances in Neural Information Processing Systems 23
(NIPS 2010).
2. Reactive LDA Library blogpost by Yingjie Miao for a similar Gibbs model
135 changes: 135 additions & 0 deletions examples/pytorch/lda/example_20newsgroups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright 2021 Yifei Ma
# Modified from scikit-learn example "plot_topics_extraction_with_nmf_lda.py"
# with the following original authors with BSD 3-Clause:
# * Olivier Grisel <[email protected]>
# * Lars Buitinck
# * Chyi-Kwei Yau <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from time import time
import matplotlib.pyplot as plt
import warnings

import numpy as np
import scipy.sparse as ss
import torch
import dgl
from dgl import function as fn

from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.decomposition import NMF, LatentDirichletAllocation
from sklearn.datasets import fetch_20newsgroups

from lda_model import LatentDirichletAllocation as LDAModel

n_samples = 2000
n_features = 1000
n_components = 10
n_top_words = 20
device = 'cuda'

def plot_top_words(model, feature_names, n_top_words, title):
fig, axes = plt.subplots(2, 5, figsize=(30, 15), sharex=True)
axes = axes.flatten()
for topic_idx, topic in enumerate(model.components_):
top_features_ind = topic.argsort()[:-n_top_words - 1:-1]
top_features = [feature_names[i] for i in top_features_ind]
weights = topic[top_features_ind]

ax = axes[topic_idx]
ax.barh(top_features, weights, height=0.7)
ax.set_title(f'Topic {topic_idx +1}',
fontdict={'fontsize': 30})
ax.invert_yaxis()
ax.tick_params(axis='both', which='major', labelsize=20)
for i in 'top right left'.split():
ax.spines[i].set_visible(False)
fig.suptitle(title, fontsize=40)

plt.subplots_adjust(top=0.90, bottom=0.05, wspace=0.90, hspace=0.3)
plt.show()

# Load the 20 newsgroups dataset and vectorize it. We use a few heuristics
# to filter out useless terms early on: the posts are stripped of headers,
# footers and quoted replies, and common English words, words occurring in
# only one document or in at least 95% of the documents are removed.

print("Loading dataset...")
t0 = time()
data, _ = fetch_20newsgroups(shuffle=True, random_state=1,
remove=('headers', 'footers', 'quotes'),
return_X_y=True)
data_samples = data[:n_samples]
data_test = data[n_samples:2*n_samples]
print("done in %0.3fs." % (time() - t0))

# Use tf (raw term count) features for LDA.
print("Extracting tf features for LDA...")
tf_vectorizer = CountVectorizer(max_df=0.95, min_df=2,
max_features=n_features,
stop_words='english')
t0 = time()
tf_vectorizer.fit(data)
tf = tf_vectorizer.transform(data_samples)
tt = tf_vectorizer.transform(data_test)

tf_feature_names = tf_vectorizer.get_feature_names()
tf_uv = [(u,v)
for u,v,e in zip(tf.tocoo().row, tf.tocoo().col, tf.tocoo().data)
for _ in range(e)]
tt_uv = [(u,v)
for u,v,e in zip(tt.tocoo().row, tt.tocoo().col, tt.tocoo().data)
for _ in range(e)]
print("done in %0.3fs." % (time() - t0))
print()

print("Preparing dgl graphs...")
t0 = time()
G = dgl.heterograph({('doc','topic','word'): tf_uv}, device=device)
Gt = dgl.heterograph({('doc','topic','word'): tt_uv}, device=device)
print("done in %0.3fs." % (time() - t0))
print()

print("Training dgl-lda model...")
t0 = time()
model = LDAModel(G, n_components)
model.fit(G)
print("done in %0.3fs." % (time() - t0))
print()

print(f"dgl-lda training perplexity {model.perplexity(G):.3f}")
print(f"dgl-lda testing perplexity {model.perplexity(Gt):.3f}")

plot_top_words(
type('dummy', (object,), {'components_': G.ndata['z']['word'].cpu().numpy().T}),
tf_feature_names, n_top_words, 'Topics in LDA model')

print("Training scikit-learn model...")

print('\n' * 2, "Fitting LDA models with tf features, "
"n_samples=%d and n_features=%d..."
% (n_samples, n_features))
lda = LatentDirichletAllocation(n_components=n_components, max_iter=5,
learning_method='online',
learning_offset=50.,
random_state=0,
verbose=1,
)
t0 = time()
lda.fit(tf)
print("done in %0.3fs." % (time() - t0))
print()

print(f"scikit-learn training perplexity {lda.perplexity(tf):.3f}")
print(f"scikit-learn testing perplexity {lda.perplexity(tt):.3f}")
Loading

0 comments on commit c018436

Please sign in to comment.