diff --git a/README.md b/README.md index 4a3a6e35a..400fa3b86 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,7 @@ The recommender models supported by Cornac are listed below. Why don't you join | | [Visual Bayesian Personalized Ranking (VBPR)](cornac/models/vbpr), [paper](https://arxiv.org/pdf/1510.01784.pdf) | [requirements.txt](cornac/models/vbpr/requirements.txt) | [vbpr_tradesy.py](examples/vbpr_tradesy.py) | 2015 | [Collaborative Deep Learning (CDL)](cornac/models/cdl), [paper](https://arxiv.org/pdf/1409.2944.pdf) | [requirements.txt](cornac/models/cdl/requirements.txt) | [cdl_exp.py](examples/cdl_example.py) | | [Hierarchical Poisson Factorization (HPF)](cornac/models/hpf), [paper](http://jakehofman.com/inprint/poisson_recs.pdf) | N/A | [hpf_movielens.py](examples/hpf_movielens.py) +| | [TriRank: Review-aware Explainable Recommendation by Modeling Aspects](cornac/models/trirank), [paper](https://wing.comp.nus.edu.sg/wp-content/uploads/Publications/PDF/TriRank-%20Review-aware%20Explainable%20Recommendation%20by%20Modeling%20Aspects.pdf) | N/A | [trirank_example.py](examples/trirank_example.py) | 2014 | [Explicit Factor Model (EFM)](cornac/models/efm), [paper](http://yongfeng.me/attach/efm-zhang.pdf) | N/A | [efm_exp.py](examples/efm_example.py) | | [Social Bayesian Personalized Ranking (SBPR)](cornac/models/sbpr), [paper](https://cseweb.ucsd.edu/~jmcauley/pdfs/cikm14.pdf) | N/A | [sbpr_epinions.py](examples/sbpr_epinions.py) | 2013 | [Hidden Factors and Hidden Topics (HFT)](cornac/models/hft), [paper](https://cs.stanford.edu/people/jure/pubs/reviews-recsys13.pdf) | N/A | [hft_exp.py](examples/hft_example.py) diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py index ce6ca52f0..37fa12e55 100644 --- a/cornac/models/__init__.py +++ b/cornac/models/__init__.py @@ -58,6 +58,7 @@ from .skm import SKMeans from .sorec import SoRec from .svd import SVD +from .trirank import TriRank from .vaecf import VAECF from .vbpr import VBPR from .vmf import VMF diff --git a/cornac/models/trirank/__init__.py b/cornac/models/trirank/__init__.py new file mode 100644 index 000000000..9d9f22c62 --- /dev/null +++ b/cornac/models/trirank/__init__.py @@ -0,0 +1 @@ +from .recom_trirank import TriRank \ No newline at end of file diff --git a/cornac/models/trirank/recom_trirank.py b/cornac/models/trirank/recom_trirank.py new file mode 100644 index 000000000..2b87cf65b --- /dev/null +++ b/cornac/models/trirank/recom_trirank.py @@ -0,0 +1,332 @@ +# Copyright 2018 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +from scipy.sparse import csr_matrix +from tqdm.auto import tqdm + +from ..recommender import Recommender +from ...utils import get_rng +from ...utils.init_utils import uniform +from ...exception import ScoreException + + +EPS = 1e-10 + + +class TriRank(Recommender): + """TriRank: Review-aware Explainable Recommendation by Modeling Aspects. + + Parameters + ---------- + name: string, optional, default: 'TriRank' + The name of the recommender model. + + alpha: float, optional, default: 1 + The weight of smoothness on user-item relation + + beta: float, optional, default: 1 + The weight of smoothness on item-aspect relation + + gamma: float, optional, default: 1 + The weight of smoothness on user-aspect relation + + eta_U: float, optional, default: 1 + The weight of fitting constraint on users + + eta_P: float, optional, default: 1 + The weight of fitting constraint on items + + eta_A: float, optional, default: 1 + The weight of fitting constraint on aspects + + max_iter: int, optional, default: 100 + Maximum number of iterations to stop online training. If set to `max_iter=-1`, \ + the online training will stop when model parameters are converged. + + trainable: boolean, optional, default: True + When False, the model is not trained and Cornac assumes that the model already \ + pre-trained (R, X, Y, p, a, u are not None). + + verbose: boolean, optional, default: False + When True, running logs are displayed. + + init_params: dictionary, optional, default: None + List of initial parameters, e.g., init_params = {'R':R, 'X':X, 'Y':Y, 'p':p, 'a':a, 'u':u} + + R: csr_matrix, shape (n_users, n_items) + The symmetric normalized of edge weight matrix of user-item relation, optional initialization via init_params + + X: csr_matrix, shape (n_users, n_aspects) + The symmetric normalized of edge weight matrix of user-aspect relation, optional initialization via init_params + + Y: csr_matrix, shape (n_items, n_aspects) + The symmetric normalized of edge weight matrix of item-aspect relation, optional initialization via init_params + + p: ndarray, shape (n_items,) + Initialized item weights, optional initialization via init_params + + a: ndarray, shape (n_aspects,) + Initialized aspect weights, optional initialization via init_params + + u: ndarray, shape (n_aspects,) + Initialized user weights, optional initialization via init_params + + seed: int, optional, default: None + Random seed for parameters initialization. + + References + ---------- + He, Xiangnan, Tao Chen, Min-Yen Kan, and Xiao Chen. 2014. \ + TriRank: Review-aware Explainable Recommendation by Modeling Aspects. \ + In the 24th ACM international on conference on information and knowledge management (CIKM'15). \ + ACM, New York, NY, USA, 1661-1670. DOI: https://doi.org/10.1145/2806416.2806504 + """ + + def __init__( + self, + name="TriRank", + alpha=1, + beta=1, + gamma=1, + eta_U=1, + eta_P=1, + eta_A=1, + max_iter=100, + verbose=True, + init_params=None, + seed=None, + ): + super().__init__(name) + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.eta_U = eta_U + self.eta_P = eta_P + self.eta_A = eta_A + self.max_iter = max_iter + self.verbose = verbose + self.seed = seed + self.rng = get_rng(seed) + + # Init params if provided + self.init_params = {} if init_params is None else init_params + self.R = self.init_params.get("R", None) + self.X = self.init_params.get("X", None) + self.Y = self.init_params.get("Y", None) + self.p = self.init_params.get("p", None) + self.a = self.init_params.get("a", None) + self.u = self.init_params.get("u", None) + + def _init(self): + # Initialize user, item and aspect rank. + if self.p is None: + self.p = uniform(self.train_set.num_items, random_state=self.rng) + if self.a is None: + self.a = uniform( + self.train_set.sentiment.num_aspects, random_state=self.rng + ) + if self.u is None: + self.u = uniform(self.train_set.num_users, random_state=self.rng) + + def _symmetrical_normalization(self, matrix: csr_matrix): + row = [] + col = [] + data = [] + row_norm = np.sqrt(matrix.sum(axis=1).A1) + col_norm = np.sqrt(matrix.sum(axis=0).A1) + for i, j in zip(*matrix.nonzero()): + row.append(i) + col.append(j) + data.append(matrix[i, j] / (row_norm[i] * col_norm[j])) + + return csr_matrix((data, (row, col)), shape=matrix.shape) + + def _create_matrices(self, train_set): + from time import time + + start_time = time() + if self.verbose: + print("Building matrices started!") + sentiment_modality = train_set.sentiment + n_users = train_set.num_users + n_items = train_set.num_items + n_aspects = sentiment_modality.num_aspects + + X_row = [] + X_col = [] + X_data = [] + Y_row = [] + Y_col = [] + Y_data = [] + for uid, isid in tqdm( + sentiment_modality.user_sentiment.items(), + disable=not self.verbose, + desc="Building matrices", + ): + for iid, sid in isid.items(): + aos = sentiment_modality.sentiment[sid] + aids = set(aid for aid, _, _ in aos) # Only one per review/sid + for aid in aids: + X_row.append(iid) + X_col.append(aid) + X_data.append(1) + Y_row.append(uid) + Y_col.append(aid) + Y_data.append(1) + + # Algorithm 1: Offline training line 2 + X = csr_matrix((X_data, (X_row, X_col)), shape=(n_items, n_aspects)) + Y = csr_matrix((Y_data, (Y_row, Y_col)), shape=(n_users, n_aspects)) + + # Algorithm 1: Offline training line 3 + X.data = np.log2(X.data) + 1 + Y.data = np.log2(Y.data) + 1 + + # Algorithm 1: Offline training line 4 + if self.verbose: + print("Building symmetric normalized matrices R, X, Y") + self.R = self._symmetrical_normalization(train_set.csr_matrix) + self.X = self._symmetrical_normalization(X) + self.Y = self._symmetrical_normalization(Y) + + if self.verbose: + total_time = time() - start_time + print("Building matrices completed in %d s" % total_time) + + def fit(self, train_set, val_set=None): + """Fit the model to observations. + + Parameters + ---------- + train_set: :obj:`cornac.data.Dataset`, required + User-Item preference data as well as additional modalities. + + val_set: :obj:`cornac.data.Dataset`, optional, default: None + User-Item preference data for model selection purposes (e.g., early stopping). + + Returns + ------- + self : object + """ + Recommender.fit(self, train_set, val_set) + self._init() + + if not self.trainable: + return self + + # Offline training: Build item-aspect matrix X and user-aspect matrix Y + self._create_matrices(train_set) + return self + + def _online_recommendation(self, user): + # Algorithm 1: Online recommendation line 5 + p_0 = self.train_set.csr_matrix[[user]] + p_0.data.fill(1) + p_0 = p_0.toarray().squeeze() + a_0 = self.Y[user].toarray().squeeze() + u_0 = np.zeros(self.train_set.csr_matrix.shape[0]) + u_0[user] = 1 + + # Algorithm 1: Online training line 6 + if p_0.any(): + p_0 /= np.linalg.norm(p_0, 1) + if a_0.any(): + a_0 /= np.linalg.norm(a_0, 1) + if u_0.any(): + u_0 /= np.linalg.norm(u_0, 1) + + # Algorithm 1: Online recommendation line 7 + p = self.p.copy() + a = self.a.copy() + u = self.u.copy() + + # Algorithm 1: Online recommendation line 8 + prev_p = p + prev_a = a + prev_u = u + inc = 1 + while True: + # eq. 4 + u_denominator = self.alpha + self.gamma + self.eta_U + EPS + u = ( + self.alpha / u_denominator * self.R * p + + self.gamma / u_denominator * self.Y * a + + self.eta_U / u_denominator * u_0 + ).squeeze() + p_denominator = self.alpha + self.beta + self.eta_P + EPS + p = ( + self.alpha / p_denominator * self.R.T * u + + self.beta / p_denominator * self.X * a + + self.eta_P / p_denominator * p_0 + ).squeeze() + a_denominator = self.gamma + self.beta + self.eta_A + EPS + a = ( + self.gamma / a_denominator * self.Y.T * u + + self.beta / a_denominator * self.X.T * p + + self.eta_P / a_denominator * a_0 + ).squeeze() + + if (self.max_iter > 0 and inc > self.max_iter) or ( + np.all(np.isclose(u, prev_u)) + and np.all(np.isclose(p, prev_p)) + and np.all(np.isclose(a, prev_a)) + ): # stop when converged + break + prev_p, prev_a, prev_u = p, a, u + inc += 1 + + # Algorithm 1: Online recommendation line 9 + return p, a, u + + def score(self, u_idx, i_idx=None): + """Predict the scores/ratings of a user for an item. + + Parameters + ---------- + u_idx: int, required + The index of the user for whom to perform score prediction. + + i_idx: int, optional, default: None + The index of the item for which to perform score prediction. + If None, scores for all known items will be returned. + + Returns + ------- + res : A scalar or a Numpy array + Relative scores that the user gives to the item or to all known items + + """ + if self.train_set.is_unk_user(u_idx): + raise ScoreException("Can't make score prediction for (user_id=%d" & u_idx) + if i_idx is not None and self.train_set.is_unk_item(i_idx): + raise ScoreException("Can't make score prediction for (item_id=%d" & i_idx) + + item_scores, *_ = self._online_recommendation(u_idx) + # Set already rated items to zero. + item_scores[self.train_set.csr_matrix[u_idx].indices] = 0 + + # Scale to match rating scale. + item_scores = ( + item_scores + * (self.train_set.max_rating - self.train_set.min_rating) + / max(item_scores) + + self.train_set.min_rating + ) + + if i_idx is None: + return item_scores + else: + return item_scores[i_idx] diff --git a/docs/source/models.rst b/docs/source/models.rst index 33cf8d363..aeb06d2f0 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -153,6 +153,11 @@ Hierarchical Poisson Factorization (HPF) .. automodule:: cornac.models.hpf.recom_hpf :members: +TriRank: Review-aware Explainable Recommendation by Modeling Aspects (TriRank) +-------------------------------------------- +.. automodule:: cornac.models.trirank.recom_trirank + :members: + Explicit Factor Model (EFM) -------------------------------------------- .. automodule:: cornac.models.efm.recom_efm diff --git a/examples/README.md b/examples/README.md index 91530dadf..0b09584d7 100644 --- a/examples/README.md +++ b/examples/README.md @@ -50,6 +50,8 @@ [cvae_example.py](cvae_example.py) - Collaborative Variational Autoencoder (CVAE) with CiteULike dataset. +[trirank_example.py](trirank_example.py) - TriRank with Amazon Toy and Games dataset. + [efm_example.py](efm_example.py) - Explicit Factor Model (EFM) with Amazon Toy and Games dataset. [hft_example.py](hft_example.py) - Hidden Factor Topic (HFT) with MovieLen 1m dataset. diff --git a/examples/trirank_example.py b/examples/trirank_example.py new file mode 100644 index 000000000..62e9ede15 --- /dev/null +++ b/examples/trirank_example.py @@ -0,0 +1,53 @@ +# Copyright 2018 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""TriRank: Review-aware Explainable Recommendation by Modeling Aspects""" + +import cornac +from cornac.datasets import amazon_toy +from cornac.data import SentimentModality +from cornac.eval_methods import RatioSplit + + +# Load rating and sentiment information +rating = amazon_toy.load_feedback() +sentiment = amazon_toy.load_sentiment() + +# Instantiate a SentimentModality, it makes it convenient to work with sentiment information +md = SentimentModality(data=sentiment) + +# Define an evaluation method to split feedback into train and test sets +eval_method = RatioSplit( + data=rating, + test_size=0.15, + exclude_unknowns=True, + verbose=True, + sentiment=md, + seed=123, +) + +# Instantiate the model +trirank = cornac.models.TriRank( + verbose=True, + seed=123, +) + +# Instantiate evaluation metrics +ndcg_50 = cornac.metrics.NDCG(k=50) +auc = cornac.metrics.AUC() + +# Put everything together into an experiment and run it +cornac.Experiment( + eval_method=eval_method, models=[trirank], metrics=[ndcg_50, auc] +).run()