Commit

refactorization
XU-YaoKun committed Jan 22, 2020
1 parent 6fc1380 commit 55782cd
Showing 20 changed files with 439 additions and 403 deletions.
File renamed without changes.
File renamed without changes.
100 changes: 100 additions & 0 deletions common/config/parser.py
@@ -0,0 +1,100 @@
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="Run KGPolicy2.")
    # ------------------------- experimental settings specific for data set --------------------------------------------
    parser.add_argument(
        "--data_path", nargs="?", default="../Data/", help="Input data path."
    )
    parser.add_argument(
        "--dataset", nargs="?", default="last-fm", help="Choose a dataset."
    )
    parser.add_argument("--emb_size", type=int, default=64, help="Embedding size.")
    parser.add_argument(
        "--regs",
        nargs="?",
        default="1e-5",
        help="Regularization for user and item embeddings.",
    )
    parser.add_argument("--gpu_id", type=int, default=0, help="gpu id")
    parser.add_argument(
        "--k_neg", type=int, default=1, help="number of negative items in list"
    )

    # ------------------------- experimental settings specific for recommender -----------------------------------------
    parser.add_argument(
        "--slr", type=float, default=0.0001, help="Learning rate for sampler."
    )
    parser.add_argument(
        "--rlr", type=float, default=0.0001, help="Learning rate for recommender."
    )

    # ------------------------- experimental settings specific for sampler ---------------------------------------------
    parser.add_argument(
        "--edge_threshold",
        type=int,
        default=64,
        help="edge threshold to filter knowledge graph",
    )
    parser.add_argument(
        "--num_sample", type=int, default=32, help="number of samples from gcn"
    )
    parser.add_argument(
        "--k_step", type=int, default=2, help="k step from current positive items"
    )
    parser.add_argument(
        "--in_channel", type=str, default="[64, 32]", help="input channels for gcn"
    )
    parser.add_argument(
        "--out_channel", type=str, default="[32, 64]", help="output channels for gcn"
    )
    parser.add_argument(
        "--pretrain_s",
        type=bool,
        default=False,
        help="load pretrained sampler data or not",
    )

    # ------------------------- experimental settings specific for training --------------------------------------------
    parser.add_argument(
        "--batch_size", type=int, default=1024, help="batch size for training."
    )
    parser.add_argument(
        "--test_batch_size", type=int, default=1024, help="batch size for test"
    )
    parser.add_argument("--num_threads", type=int, default=4, help="number of threads.")
    parser.add_argument("--epoch", type=int, default=400, help="Number of epochs.")
    parser.add_argument("--show_step", type=int, default=3, help="test step.")
    parser.add_argument(
        "--adj_epoch", type=int, default=1, help="build adj matrix per _ epoch"
    )
    parser.add_argument(
        "--pretrain_r", type=bool, default=True, help="use pretrained model or not"
    )
    parser.add_argument(
        "--freeze_s",
        type=bool,
        default=False,
        help="freeze parameters of recommender or not",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="model/best_fm.ckpt",
        help="path for pretrained model",
    )
    parser.add_argument(
        "--out_dir", type=str, default="./weights/", help="output directory for model"
    )
    parser.add_argument("--flag_step", type=int, default=32, help="early stop steps")
    parser.add_argument(
        "--gamma", type=float, default=0.99, help="gamma for reward accumulation"
    )

    # ------------------------- experimental settings specific for testing ---------------------------------------------
    parser.add_argument(
        "--Ks", nargs="?", default="[20, 40, 60, 80, 100]", help="evaluate K list"
    )

    return parser.parse_args()
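
Note that the list-valued options (--regs, --in_channel, --out_channel, --Ks) are stored as plain strings, so downstream code has to parse them before use. A minimal sketch of how this config might be consumed, assuming ast.literal_eval as the parsing step (that step is not part of this commit):

import ast

from common.config.parser import parse_args

args_config = parse_args()

# String-encoded lists need an explicit parse before use.
ks = ast.literal_eval(args_config.Ks)  # e.g. [20, 40, 60, 80, 100]
in_channel = ast.literal_eval(args_config.in_channel)  # e.g. [64, 32]

# Caution: argparse's type=bool treats any non-empty string as True,
# so a command-line override like "--pretrain_s False" still yields True.
print(args_config.dataset, ks, in_channel)
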
File renamed without changes.
22 changes: 22 additions & 0 deletions common/dataset/build.py
@@ -0,0 +1,22 @@
from torch.utils.data import DataLoader
from common.dataset.dataset import TrainGenerator, TestGenerator


def build_loader(args_config, graph):
    train_generator = TrainGenerator(args_config=args_config, graph=graph)
    train_loader = DataLoader(
        train_generator,
        batch_size=args_config.batch_size,
        shuffle=True,
        num_workers=args_config.num_threads,
    )

    test_generator = TestGenerator(args_config=args_config, graph=graph)
    test_loader = DataLoader(
        test_generator,
        batch_size=args_config.test_batch_size,
        shuffle=False,
        num_workers=args_config.num_threads,
    )

    return train_loader, test_loader
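
Taken together with parser.py above, wiring the loaders up might look like the following sketch; passing a CKGData instance (from preprocess.py below) as the graph argument is an assumption based on the generators' constructor signature, not something this commit shows:

from common.config.parser import parse_args
from common.dataset.build import build_loader
from common.dataset.preprocess import CKGData

args_config = parse_args()
graph = CKGData(args_config)  # assumed source of the graph argument

train_loader, test_loader = build_loader(args_config, graph)
for batch in train_loader:
    # PyTorch's default collate turns each {"u_id", "pos_i_id", "neg_i_id"}
    # sample from TrainGenerator into batched tensors.
    break
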
12 changes: 7 additions & 5 deletions dataloader/data_generator.py → common/dataset/dataset.py
@@ -26,7 +26,7 @@ def __getitem__(self, index):
         user_dict = self.user_dict
         # randomly select one user.
         u_id = random.sample(self.exist_users, 1)[0]
-        out_dict['u_id'] = u_id
+        out_dict["u_id"] = u_id
 
         # randomly select one positive item.
         pos_items = user_dict[u_id]
@@ -35,16 +35,18 @@ def __getitem__(self, index):
         pos_idx = np.random.randint(low=0, high=n_pos_items, size=1)[0]
         pos_i_id = pos_items[pos_idx]
 
-        out_dict['pos_i_id'] = pos_i_id
+        out_dict["pos_i_id"] = pos_i_id
 
         neg_i_id = self.get_random_neg(pos_items, [])
-        out_dict['neg_i_id'] = neg_i_id
+        out_dict["neg_i_id"] = neg_i_id
 
         return out_dict
 
     def get_random_neg(self, pos_items, selected_items):
         while True:
-            neg_i_id = np.random.randint(low=self.low_item_index, high=self.high_item_index, size=1)[0]
+            neg_i_id = np.random.randint(
+                low=self.low_item_index, high=self.high_item_index, size=1
+            )[0]
 
             if neg_i_id not in pos_items and neg_i_id not in selected_items:
                 break
@@ -63,6 +65,6 @@ def __getitem__(self, index):
         batch_data = {}
 
         u_id = self.users_to_test[index]
-        batch_data['u_id'] = u_id
+        batch_data["u_id"] = u_id
 
         return batch_data
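
get_random_neg above is plain rejection sampling: draw a uniform item id and retry until it is neither a known positive nor already selected. A self-contained illustration (the helper name and toy ranges are made up for this sketch):

import numpy as np

def sample_negative(pos_items, low, high):
    # Draw uniformly over [low, high) and reject known positives,
    # mirroring TrainGenerator.get_random_neg.
    while True:
        neg_i_id = np.random.randint(low=low, high=high, size=1)[0]
        if neg_i_id not in pos_items:
            return neg_i_id

print(sample_negative({3, 7}, low=0, high=10))
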
74 changes: 42 additions & 32 deletions dataloader/data_processor.py → common/dataset/preprocess.py
@@ -1,20 +1,16 @@
 import collections
 import numpy as np
 import networkx as nx
-import pickle
-import os
 from tqdm import tqdm
-from utility.helper import ensure_dir
-from time import time
 
 
 class CFData(object):
     def __init__(self, args_config):
         self.args_config = args_config
 
         path = args_config.data_path + args_config.dataset
-        train_file = path + '/train.dat'
-        test_file = path + '/test.dat'
+        train_file = path + "/train.dat"
+        test_file = path + "/test.dat"
 
         # ----------get number of users and items & then load rating data from train_file & test_file------------
         self.train_data = self._generate_interactions(train_file)
@@ -30,10 +26,10 @@ def __init__(self, args_config):
     def _generate_interactions(file_name):
         inter_mat = list()
 
-        lines = open(file_name, 'r').readlines()
+        lines = open(file_name, "r").readlines()
         for l in lines:
             tmps = l.strip()
-            inters = [int(i) for i in tmps.split(' ')]
+            inters = [int(i) for i in tmps.split(" ")]
 
             u_id, pos_ids = inters[0], inters[1:]
             pos_ids = list(set(pos_ids))
@@ -70,19 +66,23 @@ def _id_range(train_mat, test_mat, idx):
             n_id = max_id - min_id + 1
             return (min_id, max_id), n_id
 
-        self.user_range, self.n_users = _id_range(self.train_data, self.test_data, idx=0)
-        self.item_range, self.n_items = _id_range(self.train_data, self.test_data, idx=1)
+        self.user_range, self.n_users = _id_range(
+            self.train_data, self.test_data, idx=0
+        )
+        self.item_range, self.n_items = _id_range(
+            self.train_data, self.test_data, idx=1
+        )
         self.n_train = len(self.train_data)
         self.n_test = len(self.test_data)
 
-        print('-'*50)
-        print('- user_range: (%d, %d)' % (self.user_range[0], self.user_range[1]))
-        print('- item_range: (%d, %d)' % (self.item_range[0], self.item_range[1]))
-        print('- n_train: %d' % self.n_train)
-        print('- n_test: %d' % self.n_test)
-        print('- n_users: %d' % self.n_users)
-        print('- n_items: %d' % self.n_items)
-        print('-'*50)
+        print("-" * 50)
+        print("- user_range: (%d, %d)" % (self.user_range[0], self.user_range[1]))
+        print("- item_range: (%d, %d)" % (self.item_range[0], self.item_range[1]))
+        print("- n_train: %d" % self.n_train)
+        print("- n_test: %d" % self.n_test)
+        print("- n_users: %d" % self.n_users)
+        print("- n_items: %d" % self.n_items)
+        print("-" * 50)
 
 
 class KGData(object):
@@ -92,7 +92,7 @@ def __init__(self, args_config, entity_start_id=0, relation_start_id=0):
         self.relation_start_id = relation_start_id
 
         path = args_config.data_path + args_config.dataset
-        kg_file = path + '/kg_final.txt'
+        kg_file = path + "/kg_final.txt"
 
         # ----------get number of entities and relations & then load kg data from kg_file ------------.
         self.kg_data, self.kg_dict, self.relation_dict = self._load_kg(kg_file)
@@ -140,30 +140,40 @@ def _construct_kg(kg_np):
 
     def _statistic_kg_triples(self):
         def _id_range(kg_mat, idx):
-            min_id = min(min(kg_mat[:, idx]), min(kg_mat[:, 2-idx]))
-            max_id = max(max(kg_mat[:, idx]), max(kg_mat[:, 2-idx]))
+            min_id = min(min(kg_mat[:, idx]), min(kg_mat[:, 2 - idx]))
+            max_id = max(max(kg_mat[:, idx]), max(kg_mat[:, 2 - idx]))
             n_id = max_id - min_id + 1
             return (min_id, max_id), n_id
 
         self.entity_range, self.n_entities = _id_range(self.kg_data, idx=0)
         self.relation_range, self.n_relations = _id_range(self.kg_data, idx=1)
         self.n_kg_triples = len(self.kg_data)
 
-        print('-'*50)
-        print('- entity_range: (%d, %d)' % (self.entity_range[0], self.entity_range[1]))
-        print('- relation_range: (%d, %d)' % (self.relation_range[0], self.relation_range[1]))
-        print('- n_entities: %d' % self.n_entities)
-        print('- n_relations: %d' % self.n_relations)
-        print('- n_kg_triples: %d' % self.n_kg_triples)
-        print('-'*50)
+        print("-" * 50)
+        print(
+            "- entity_range: (%d, %d)" % (self.entity_range[0], self.entity_range[1])
+        )
+        print(
+            "- relation_range: (%d, %d)"
+            % (self.relation_range[0], self.relation_range[1])
+        )
+        print("- n_entities: %d" % self.n_entities)
+        print("- n_relations: %d" % self.n_relations)
+        print("- n_kg_triples: %d" % self.n_kg_triples)
+        print("-" * 50)
 
 
 class CKGData(CFData, KGData):
     def __init__(self, args_config):
         CFData.__init__(self, args_config=args_config)
-        KGData.__init__(self, args_config=args_config, entity_start_id=self.n_users, relation_start_id=2)
+        KGData.__init__(
+            self,
+            args_config=args_config,
+            entity_start_id=self.n_users,
+            relation_start_id=2,
+        )
         self.args_config = args_config
 
         self.ckg_graph = self._combine_cf_kg()
 
     def _combine_cf_kg(self):
@@ -176,12 +186,12 @@ def _combine_cf_kg(self):
         # ... ids of other entities in range of [#users + #items, #users + #entities)
         # ... ids of relations in range of [0, 2 + 2 * #kg relations), including two 'interact' and 'interacted_by'.
         ckg_graph = nx.MultiDiGraph()
-        print('Begin to load interaction triples ...')
+        print("Begin to load interaction triples ...")
         for u_id, i_id in tqdm(cf_mat, ascii=True):
             ckg_graph.add_edges_from([(u_id, i_id)], r_id=0)
             ckg_graph.add_edges_from([(i_id, u_id)], r_id=1)
 
-        print('\nBegin to load knowledge graph triples ...')
+        print("\nBegin to load knowledge graph triples ...")
         for h_id, r_id, t_id in tqdm(kg_mat, ascii=True):
             ckg_graph.add_edges_from([(h_id, t_id)], r_id=r_id)
         return ckg_graph
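
The comments in _combine_cf_kg pin down the id layout of the combined graph: users first, then items, then the remaining KG entities, with r_id 0 and 1 reserved for the two interaction relations. A toy illustration of that layout (all concrete ids below are invented):

import networkx as nx

g = nx.MultiDiGraph()
n_users = 2  # toy value

g.add_edges_from([(0, n_users + 0)], r_id=0)  # user 0 'interact' item 0
g.add_edges_from([(n_users + 0, 0)], r_id=1)  # item 0 'interacted_by' user 0
g.add_edges_from([(n_users + 0, n_users + 1)], r_id=2)  # a KG triple

for head, tail, attrs in g.edges(data=True):
    print(head, tail, attrs["r_id"])
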