Skip to content

Commit

Permalink
enable use of unigram frequency for sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
meereeum committed Dec 11, 2016
1 parent c2cd8de commit d0d8ca2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
18 changes: 10 additions & 8 deletions lda2vec/negative_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ class NegativeSampling():
# IGNORE_LABEL_MAX = 1 # ignore any labels <=1 (OOV or skip)

def __init__(self, embedding_size, vocabulary_size, sample_size, power=1.,
W_in=None):
freqs=None, W_in=None):
self.vocab_size = vocabulary_size
self.sample_size = sample_size
self.power = power
self.freqs = freqs

# via https://github.com/tensorflow/tensorflow/blob/r0.11/tensorflow/examples/tutorials/word2vec/word2vec_basic.py

Expand Down Expand Up @@ -82,17 +83,18 @@ def __call__(self, embed, train_labels):
# time we evaluate the loss.
# By default this uses a log-uniform (Zipfian) distribution for sampling
# and therefore assumes labels are sorted - which they are!
# sampler = tf.nn.fixed_unigram_candidate_sampler(
# train_labels, num_true=1, num_sampled=self.sample_size,
# unique=True, range_max=self.vocab_size,
# num_reserved_ids=[0,1], # skip or OoV
# distortion=self.power, unigrams) # TODO
sampler = (self.freqs if self.freqs is None else
tf.nn.fixed_unigram_candidate_sampler(
train_labels, num_true=1, num_sampled=self.sample_size,
unique=True, range_max=self.vocab_size,
num_reserved_ids=[0,1], # skip or OoV
distortion=self.power, unigrams=self.freqs)) # TODO

loss = tf.reduce_mean(
tf.nn.nce_loss(self.nce_weights, self.nce_biases,
embed, # summed doc and context embedding
train_labels, self.sample_size, self.vocab_size),
# sampled_values=sampler), # log-unigram if not specificed
train_labels, self.sample_size, self.vocab_size,
sampled_values=sampler), # log-unigram if not specified
name="nce_batch_loss")
# TODO negative sampling versus NCE
# TODO uniform vs. Zipf with exponent `distortion` param
Expand Down
5 changes: 4 additions & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class LDA2Vec():
"temperature": 1., # embed mixture temp
"lmbda": 200., # strength of Dirichlet prior
"alpha": None, # alpha of Dirichlet process (defaults to 1/n_topics)

"freqs": None
}
RESTORE_KEY = "to_restore"

Expand All @@ -49,7 +51,8 @@ def __init__(self, n_documents=None, n_vocab=None, d_hyperparams={},#train=True,
n_documents, self.n_document_topics, self.n_embedding,
temperature=self.temperature)
self.sampler = NegativeSampling(
self.n_embedding, n_vocab, self.n_samples, power=self.power)
self.n_embedding, n_vocab, self.n_samples, power=self.power,
freqs=self.freqs)

handles = self._buildGraph() + (
self.mixture(), self.mixture.proportions(softmax=True),
Expand Down

0 comments on commit d0d8ca2

Please sign in to comment.