Skip to content

Commit

Permalink
[Bugfix] Fix GCMC examples (dmlc#4082)
Browse files Browse the repository at this point in the history
* [Example][Bug] Fix GCMC examples

* Revert the change to model.py introduced in commit 74f01405
  • Loading branch information
chang-l authored Jun 4, 2022
1 parent 0f2ff47 commit d31448d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 6 additions & 2 deletions examples/pytorch/gcmc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ def _process_movie_fea(self):
"""
import torchtext
from torchtext.data.utils import get_tokenizer

if self._name == 'ml-100k':
GENRES = GENRES_ML_100K
Expand All @@ -514,7 +515,9 @@ def _process_movie_fea(self):
else:
raise NotImplementedError

TEXT = torchtext.legacy.data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
# Old torchtext-legacy API commented below
# TEXT = torchtext.legacy.data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
tokenizer = get_tokenizer('spacy', language='en_core_web_sm') # new API (torchtext 0.9+)
embedding = torchtext.vocab.GloVe(name='840B', dim=300)

title_embedding = np.zeros(shape=(self.movie_info.shape[0], 300), dtype=np.float32)
Expand All @@ -528,7 +531,8 @@ def _process_movie_fea(self):
else:
title_context, year = match_res.groups()
# We use average of glove
title_embedding[i, :] = embedding.get_vecs_by_tokens(TEXT.tokenize(title_context)).numpy().mean(axis=0)
# Upgraded torchtext API: TEXT.tokenize(title_context) --> tokenizer(title_context)
title_embedding[i, :] = embedding.get_vecs_by_tokens(tokenizer(title_context)).numpy().mean(axis=0)
release_years[i] = float(year)
movie_features = np.concatenate((title_embedding,
(release_years - 1950.0) / 100.0,
Expand Down
4 changes: 2 additions & 2 deletions examples/pytorch/gcmc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def forward(self, graph, feat, weight=None):
if weight is not None:
feat = dot_or_identity(feat, weight, self.device)

feat = feat * self.dropout(cj).view(-1, 1)
feat = feat * self.dropout(cj)
graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
Expand Down Expand Up @@ -342,7 +342,7 @@ def forward(self, graph, ufeat, ifeat):
graph.apply_edges(fn.u_dot_v('h', 'h', 'sr'))
basis_out.append(graph.edata['sr'])
out = th.cat(basis_out, dim=1)
#out = self.combine_basis(out)
out = self.combine_basis(out)
return out

class DenseBiDecoder(nn.Module):
Expand Down

0 comments on commit d31448d

Please sign in to comment.