forked from TsingZ0/TLSAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuild_dataset.py
64 lines (55 loc) · 2.08 KB
/
build_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import copy
import pickle
import random
from collections import Counter

import numpy as np
# Upper bound on how many of a user's interactions are considered when the
# per-user history is split into training and test samples.
max_length = 90

# Fixed seed so that negative sampling, test-item selection and the final
# shuffles are reproducible from run to run.
random.seed(1234)

# The input file holds three consecutive pickle records; they must be read
# back in the order they were written.
# NOTE: pickle.load can execute arbitrary code; only load files you produced.
with open('../Data/Digital_Music.pkl', 'rb') as f:
    _tables = pickle.load(f)
    item_cate_list = pickle.load(f)
    _sizes = pickle.load(f)

# Record 1: the reviews table (read below through its 'reviewerID', 'asin'
# and 'unixReviewTime' columns) and the item metadata table.
reviews_df, meta_df = _tables
# Record 3: dataset-wide counters.
user_count, item_count, cate_count, example_count = _sizes
# Build the training and test samples, one user at a time.
#
# A "session" is the run of a user's items that share one `unixReviewTime`.
# For every user:
#   * the earliest session only seeds the history (`pre_session`);
#   * each later session whose end offset `i + count` is below
#     `valid_length - 1` yields two training tuples
#     `(reviewerID, history, session, item, label)`: the item that follows the
#     session with label 1, and the negative sampled for that same position
#     with label 0; the session is then appended to the history;
#   * the first later session that reaches that bound yields the user's test
#     tuple `(reviewerID, history, session, (pos_item, neg_item))` and ends
#     processing for that user.
# A user whose items all share a single timestamp therefore contributes no
# test tuple.


def _draw_negative(seen_items, n_items):
    """Return a random item id in [0, n_items) that is not in `seen_items`.

    Uses rejection sampling on the global `random` generator: ids are drawn
    with `random.randint` until one falls outside `seen_items`. The caller
    must make sure at least one such id exists, otherwise this never returns.
    """
    while True:
        candidate = random.randint(0, n_items - 1)
        if candidate not in seen_items:
            return candidate


train_set = []
test_set = []
for reviewerID, hist in reviews_df.groupby('reviewerID'):
    # The session slicing below walks the items positionally, so they have to
    # be in time order. sorted() is stable: items that share a timestamp keep
    # their original relative order, and already-sorted input is unchanged.
    events = sorted(
        zip(hist['unixReviewTime'].tolist(), hist['asin'].tolist()),
        key=lambda event: event[0],
    )
    tim_list = [timestamp for timestamp, _ in events]
    pos_list = [item for _, item in events]

    # A set gives constant-time membership tests during rejection sampling.
    seen_items = set(pos_list)
    # Sampling cannot terminate if the user already covers every candidate id;
    # fail loudly instead of looping forever. The exhaustive scan only runs
    # when the cheap size comparison makes that situation possible.
    if len(seen_items) >= item_count and all(
        item_id in seen_items for item_id in range(item_count)
    ):
        raise ValueError(
            'user {!r} has interacted with all {} items; '
            'no negative sample exists'.format(reviewerID, item_count)
        )
    # One negative per positive, aligned by index with pos_list.
    neg_list = [_draw_negative(seen_items, item_count) for _ in pos_list]

    valid_length = min(len(pos_list), max_length)

    # Session sizes in a single pass; the sorted keys are the distinct
    # timestamps in chronological order.
    session_sizes = Counter(tim_list)
    session_times = sorted(session_sizes)
    first_time = session_times[0]

    i = 0  # offset in pos_list of the first item of the current session
    pre_session = []
    for t in session_times:
        count = session_sizes[t]
        new_session = pos_list[i:i + count]
        if t == first_time:
            pre_session.extend(new_session)
        elif i + count < valid_length - 1:
            # Snapshot the history: pre_session keeps growing afterwards.
            # A shallow copy is enough for a flat list of item ids.
            history = list(pre_session)
            next_index = i + count
            train_set.append((reviewerID, history, new_session, pos_list[next_index], 1))
            train_set.append((reviewerID, history, new_session, neg_list[next_index], 0))
            pre_session.extend(new_session)
        else:
            pos_item = pos_list[i]
            if count > 1:
                # Hold one random item of the session out as the target.
                pos_item = random.choice(new_session)
                new_session.remove(pos_item)
            # NOTE(review): when count == 1 the target item is left inside
            # new_session -- confirm this is intended.
            neg_index = pos_list.index(pos_item)
            pos_neg = (pos_item, neg_list[neg_index])
            test_set.append((reviewerID, pre_session, new_session, pos_neg))
            break
        i += count
# Shuffle both splits in place (train first, then test, so the seeded random
# stream is consumed in a fixed order).
random.shuffle(train_set)
random.shuffle(test_set)

# The test split is expected to hold as many samples as there are users.
# This is an explicit check rather than a bare `assert` so that it still runs
# under `python -O`; AssertionError is kept as the exception type.
if len(test_set) != user_count:
    raise AssertionError(
        'test set size does not match user count: '
        '{} test samples for {} users'.format(len(test_set), user_count)
    )

# The output file holds three consecutive pickle records, to be read back in
# this order: train_set, test_set, (user_count, item_count).
with open('dataset.pkl', 'wb') as f:
    pickle.dump(train_set, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump(test_set, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump((user_count, item_count), f, pickle.HIGHEST_PROTOCOL)