Skip to content

Commit

Permalink
Update DreamRec.py
Browse files Browse the repository at this point in the history
  • Loading branch information
YangZhengyi98 authored Oct 23, 2023
1 parent 18966c7 commit ad1d452
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions DreamRec.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,10 +455,10 @@ def evaluate(model, test_data, diff, w, device):

total_purchase+=batch_size

print('#############################################################')

hr_list = []
ndcg_list = []
print('hr@{}\tndcg@{}\thr@{}\tndcg@{}\thr@{}\tndcg@{}'.format(topk[0], topk[0], topk[1], topk[1], topk[2], topk[2]))
print('{:<10s} {:<10s} {:<10s} {:<10s} {:<10s} {:<10s}'.format('HR@'+str(topk[0]), 'NDCG@'+str(topk[0]), 'HR@'+str(topk[1]), 'NDCG@'+str(topk[1]), 'HR@'+str(topk[2]), 'NDCG@'+str(topk[2])))
for i in range(len(topk)):
hr_purchase=hit_purchase[i]/total_purchase
ng_purchase=ndcg_purchase[i]/total_purchase
Expand All @@ -469,9 +469,9 @@ def evaluate(model, test_data, diff, w, device):
if i == 1:
hr_20 = hr_purchase

print('{:.6f}\t{:.6f}\t{:.6f}\t{:.6f}\t{:.6f}\t{:.6f}'.format(hr_list[0], (ndcg_list[0]), hr_list[1], (ndcg_list[1]), hr_list[2], (ndcg_list[2])))
print('{:.4f}&{:.4f}&{:.4f}&{:.4f}&{:.4f}&{:.4f}'.format(hr_list[0], (ndcg_list[0]), hr_list[1], (ndcg_list[1]), hr_list[2], (ndcg_list[2])))
print('#############################################################')
print('{:<10.6f} {:<10.6f} {:<10.6f} {:<10.6f} {:<10.6f} {:<10.6f}'.format(hr_list[0], (ndcg_list[0]), hr_list[1], (ndcg_list[1]), hr_list[2], (ndcg_list[2])))



return hr_20

Expand All @@ -482,8 +482,6 @@ def evaluate(model, test_data, diff, w, device):
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda)

data_directory = './data/' + args.data
# data_directory = './data/' + args.data
# data_directory = '../' + args.data + '/data'
data_statis = pd.read_pickle(
os.path.join(data_directory, 'data_statis.df')) # read data statistics, includeing seq_size and item_num
seq_size = data_statis['seq_size'][0] # the length of history to define the seq
Expand All @@ -496,7 +494,6 @@ def evaluate(model, test_data, diff, w, device):
timesteps = args.timesteps


# model = GRU(args.hidden_factor,item_num, seq_size, args.layers)
model = Tenc(args.hidden_factor,item_num, seq_size, args.dropout_rate, args.diffuser_type, device)
diff = diffusion(args.timesteps, args.beta_start, args.beta_end)

Expand Down Expand Up @@ -525,6 +522,7 @@ def evaluate(model, test_data, diff, w, device):
num_rows=train_data.shape[0]
num_batches=int(num_rows/args.batch_size)
for i in range(args.epoch):
start_time = Time.time()
for j in range(num_batches):
batch = train_data.sample(n=args.batch_size).to_dict()
seq = list(batch['seq'].values())
Expand Down Expand Up @@ -555,17 +553,18 @@ def evaluate(model, test_data, diff, w, device):

if args.report_epoch:
if i % 1 == 0:
print("the loss in %dth epoch is: %f" % (i, loss))
print("Epoch {:03d}; ".format(i) + 'Train loss: {:.4f}; '.format(loss) + "Time cost: " + Time.strftime(
"%H: %M: %S", Time.gmtime(Time.time()-start_time)))

if (i + 1) % 10 == 0:
w_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

for w in w_list:
print('VAL PHRASE{}:'.format(w))
hr_20 = evaluate(model, 'val_data.df', diff, w, device)
print('TEST PHRASE{}:'.format(w))
_ = evaluate(model, 'test_data.df', diff, w, device)

eval_start = Time.time()
print('-------------------------- VAL PHRASE --------------------------')
_ = evaluate(model, 'val_data.df', diff, args.w, device)
print('-------------------------- TEST PHRASE -------------------------')
_ = evaluate(model, 'test_data.df', diff, args.w, device)
print("Evalution cost: " + Time.strftime("%H: %M: %S", Time.gmtime(Time.time()-eval_start)))
print('----------------------------------------------------------------')



Expand Down

0 comments on commit ad1d452

Please sign in to comment.