
Commit

add comments for recurrencybaseline
JuliaGast committed May 31, 2024
1 parent 4846e06 commit b5b7077
Showing 9 changed files with 126 additions and 18 deletions.
16 changes: 14 additions & 2 deletions examples/linkproppred/thgl-forum/recurrencybaseline.py
@@ -33,6 +33,11 @@

def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
perf_list_all, hits_list_all, window, neg_sampler, split_mode):
""" create predictions for each relation on test or valid set and compute mrr
:return perf_list_all: list of mrrs for each test query
:return hits_list_all: list of hits for each test query
"""

first_ts = data_c_rel[0][3]
## use this if you want to use ray:
num_queries = len(data_c_rel) // num_processes
@@ -78,7 +83,7 @@ def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,

## test
def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
""" create predictions for each relation on test or valid set and compute mrr
""" create predictions by loopoing through all relations on test or valid set and compute mrr
:return perf_list_all: list of mrrs for each test query
:return hits_list_all: list of hits for each test query
"""
@@ -121,6 +126,10 @@ def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler,
return perf_list_all, hits_list_all

def read_dict_compute_mrr(split_mode='test'):
""" read the results per relation from a precreated file and compute mrr
:return mrr_per_rel: dictionary of mrrs for each relation
:return all_mrrs: list of mrrs for all relations
"""
csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'
# Initialize an empty dictionary to store the data
results_per_rel_dict = {}
@@ -153,7 +162,9 @@ def read_dict_compute_mrr(split_mode='test'):

## train
def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
""" optional, find best values for lambda and alpha
""" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params
based on validation mrr
:return best_config: dictionary of best params for each relation
"""
best_config= {}
best_mrr = 0
@@ -243,6 +254,7 @@ def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_

## args
def get_args():
"""parse all arguments for the script"""
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", "-d", default="thgl-forum", type=str)
parser.add_argument("--window", "-w", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep
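The docstrings added here describe the evaluation flow: test() loops through all relations, scores each relation's queries, and collects one reciprocal rank per test query. Below is a minimal runnable sketch of that pattern; score_queries stands in for the repo's predict(), and the best_config layout ({rel: {'alpha': ..., 'lmbda_psi': ...}}) is an assumption, not taken from the diff.

```python
# Sketch of the per-relation evaluation loop the new test() docstring describes.
import numpy as np

def score_queries(queries, alpha, lmbda_psi):
    # placeholder scorer: returns one reciprocal rank per query
    return [1.0 / (i + 1) for i, _ in enumerate(queries)]

def test_sketch(best_config, all_relations, test_data_prel):
    perf_list_all = []                      # one reciprocal rank per test query
    for rel in all_relations:               # "looping through all relations"
        cfg = best_config.get(rel, {'alpha': 0.99, 'lmbda_psi': 0.1})
        perf_list_all += score_queries(test_data_prel[rel],
                                       cfg['alpha'], cfg['lmbda_psi'])
    return perf_list_all

perf = test_sketch(best_config={0: {'alpha': 0.99, 'lmbda_psi': 0.1}},
                   all_relations=[0, 1],
                   test_data_prel={0: ['q1', 'q2'], 1: ['q3']})
print(f'test MRR: {np.mean(perf):.4f}')  # mean reciprocal rank over all queries
```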
16 changes: 14 additions & 2 deletions examples/linkproppred/thgl-github/recurrencybaseline.py
@@ -33,6 +33,11 @@

def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
perf_list_all, hits_list_all, window, neg_sampler, split_mode):
""" create predictions for each relation on test or valid set and compute mrr
:return perf_list_all: list of mrrs for each test query
:return hits_list_all: list of hits for each test query
"""

first_ts = data_c_rel[0][3]
## use this if you want to use ray:
num_queries = len(data_c_rel) // num_processes
@@ -78,7 +83,7 @@ def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,

## test
def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
""" create predictions for each relation on test or valid set and compute mrr
""" create predictions by loopoing through all relations on test or valid set and compute mrr
:return perf_list_all: list of mrrs for each test query
:return hits_list_all: list of hits for each test query
"""
@@ -121,6 +126,10 @@ def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler,
return perf_list_all, hits_list_all

def read_dict_compute_mrr(split_mode='test'):
""" read the results per relation from a precreated file and compute mrr
:return mrr_per_rel: dictionary of mrrs for each relation
:return all_mrrs: list of mrrs for all relations
"""
csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'
# Initialize an empty dictionary to store the data
results_per_rel_dict = {}
@@ -153,7 +162,9 @@ def read_dict_compute_mrr(split_mode='test'):

## train
def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
""" optional, find best values for lambda and alpha
""" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params
based on validation mrr
:return best_config: dictionary of best params for each relation
"""
best_config= {}
best_mrr = 0
@@ -243,6 +254,7 @@ def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_

## args
def get_args():
"""parse all arguments for the script"""
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", "-d", default="thgl-github", type=str)
parser.add_argument("--window", "-w", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep
16 changes: 14 additions & 2 deletions examples/linkproppred/thgl-myket/recurrencybaseline.py
@@ -33,6 +33,11 @@

def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
perf_list_all, hits_list_all, window, neg_sampler, split_mode):
""" create predictions for each relation on test or valid set and compute mrr
:return perf_list_all: list of mrrs for each test query
:return hits_list_all: list of hits for each test query
"""

first_ts = data_c_rel[0][3]
## use this if you want to use ray:
num_queries = len(data_c_rel) // num_processes
@@ -78,7 +83,7 @@ def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,

## test
def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
""" create predictions for each relation on test or valid set and compute mrr
""" create predictions by loopoing through all relations on test or valid set and compute mrr
:return perf_list_all: list of mrrs for each test query
:return hits_list_all: list of hits for each test query
"""
@@ -121,6 +126,10 @@ def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler,
return perf_list_all, hits_list_all

def read_dict_compute_mrr(split_mode='test'):
""" read the results per relation from a precreated file and compute mrr
:return mrr_per_rel: dictionary of mrrs for each relation
:return all_mrrs: list of mrrs for all relations
"""
csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'
# Initialize an empty dictionary to store the data
results_per_rel_dict = {}
@@ -153,7 +162,9 @@ def read_dict_compute_mrr(split_mode='test'):

## train
def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
""" optional, find best values for lambda and alpha
""" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params
based on validation mrr
:return best_config: dictionary of best params for each relation
"""
best_config= {}
best_mrr = 0
@@ -243,6 +254,7 @@ def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_

## args
def get_args():
"""parse all arguments for the script"""
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", "-d", default="thgl-myket", type=str)
parser.add_argument("--window", "-w", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep
16 changes: 14 additions & 2 deletions examples/linkproppred/thgl-software/recurrencybaseline.py
@@ -33,6 +33,11 @@

def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
perf_list_all, hits_list_all, window, neg_sampler, split_mode):
""" create predictions for each relation on test or valid set and compute mrr
:return perf_list_all: list of mrrs for each test query
:return hits_list_all: list of hits for each test query
"""

first_ts = data_c_rel[0][3]
## use this if you want to use ray:
num_queries = len(data_c_rel) // num_processes
@@ -78,7 +83,7 @@ def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,

## test
def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
""" create predictions for each relation on test or valid set and compute mrr
""" create predictions by loopoing through all relations on test or valid set and compute mrr
:return perf_list_all: list of mrrs for each test query
:return hits_list_all: list of hits for each test query
"""
@@ -121,6 +126,10 @@ def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler,
return perf_list_all, hits_list_all

def read_dict_compute_mrr(split_mode='test'):
""" read the results per relation from a precreated file and compute mrr
:return mrr_per_rel: dictionary of mrrs for each relation
:return all_mrrs: list of mrrs for all relations
"""
csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'
# Initialize an empty dictionary to store the data
results_per_rel_dict = {}
@@ -153,7 +162,9 @@ def read_dict_compute_mrr(split_mode='test'):

## train
def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
""" optional, find best values for lambda and alpha
""" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params
based on validation mrr
:return best_config: dictionary of best params for each relation
"""
best_config= {}
best_mrr = 0
@@ -243,6 +254,7 @@ def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_

## args
def get_args():
"""parse all arguments for the script"""
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", "-d", default="thgl-software", type=str)
parser.add_argument("--window", "-w", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep
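The read_dict_compute_mrr() docstring added across these files says the function reads per-relation results from a pre-created CSV and computes MRRs. A hedged sketch of that aggregation follows; the two-column layout (relation id, list of reciprocal ranks) is an assumption, since the actual file format is not shown in the diff.

```python
# Sketch of the aggregation the read_dict_compute_mrr() docstring describes.
import ast
import csv

def read_dict_compute_mrr_sketch(csv_file):
    mrr_per_rel, all_mrrs = {}, []
    with open(csv_file, newline='') as f:
        for rel, ranks_str in csv.reader(f):
            ranks = ast.literal_eval(ranks_str)   # e.g. "[0.5, 1.0, 0.25]"
            mrr_per_rel[int(rel)] = sum(ranks) / len(ranks)
            all_mrrs.extend(ranks)
    print(f'overall MRR: {sum(all_mrrs) / len(all_mrrs):.4f}')
    return mrr_per_rel, all_mrrs

# tiny demo file in the assumed layout: relation id, list of reciprocal ranks
with open('demo_results_test.csv', 'w', newline='') as f:
    csv.writer(f).writerows([(0, '[0.5, 1.0]'), (1, '[0.25]')])
mrr_per_rel, all_mrrs = read_dict_compute_mrr_sketch('demo_results_test.csv')
```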
16 changes: 14 additions & 2 deletions examples/linkproppred/tkgl-icews/recurrencybaseline.py
@@ -33,6 +33,11 @@

def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
perf_list_all, hits_list_all, window, neg_sampler, split_mode):
""" create predictions for each relation on test or valid set and compute mrr
:return perf_list_all: list of mrrs for each test query
:return hits_list_all: list of hits for each test query
"""

first_ts = data_c_rel[0][3]
## use this if you want to use ray:
num_queries = len(data_c_rel) // num_processes
@@ -78,7 +83,7 @@ def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,

## test
def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
""" create predictions for each relation on test or valid set and compute mrr
""" create predictions by loopoing through all relations on test or valid set and compute mrr
:return perf_list_all: list of mrrs for each test query
:return hits_list_all: list of hits for each test query
"""
@@ -121,6 +126,10 @@ def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler,
return perf_list_all, hits_list_all

def read_dict_compute_mrr(split_mode='test'):
""" read the results per relation from a precreated file and compute mrr
:return mrr_per_rel: dictionary of mrrs for each relation
:return all_mrrs: list of mrrs for all relations
"""
csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'
# Initialize an empty dictionary to store the data
results_per_rel_dict = {}
@@ -153,7 +162,9 @@ def read_dict_compute_mrr(split_mode='test'):

## train
def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
""" optional, find best values for lambda and alpha
""" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params
based on validation mrr
:return best_config: dictionary of best params for each relation
"""
best_config= {}
best_mrr = 0
@@ -243,6 +254,7 @@ def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_

## args
def get_args():
"""parse all arguments for the script"""
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", "-d", default="tkgl-icews", type=str)
parser.add_argument("--window", "-w", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep
16 changes: 14 additions & 2 deletions examples/linkproppred/tkgl-polecat/recurrencybaseline.py
@@ -33,6 +33,11 @@

def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
perf_list_all, hits_list_all, window, neg_sampler, split_mode):
""" create predictions for each relation on test or valid set and compute mrr
:return perf_list_all: list of mrrs for each test query
:return hits_list_all: list of hits for each test query
"""

first_ts = data_c_rel[0][3]
## use this if you want to use ray:
num_queries = len(data_c_rel) // num_processes
@@ -78,7 +83,7 @@ def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,

## test
def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
""" create predictions for each relation on test or valid set and compute mrr
""" create predictions by loopoing through all relations on test or valid set and compute mrr
:return perf_list_all: list of mrrs for each test query
:return hits_list_all: list of hits for each test query
"""
@@ -121,6 +126,10 @@ def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler,
return perf_list_all, hits_list_all

def read_dict_compute_mrr(split_mode='test'):
""" read the results per relation from a precreated file and compute mrr
:return mrr_per_rel: dictionary of mrrs for each relation
:return all_mrrs: list of mrrs for all relations
"""
csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'
# Initialize an empty dictionary to store the data
results_per_rel_dict = {}
@@ -153,7 +162,9 @@ def read_dict_compute_mrr(split_mode='test'):

## train
def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
""" optional, find best values for lambda and alpha
""" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params
based on validation mrr
:return best_config: dictionary of best params for each relation
"""
best_config= {}
best_mrr = 0
@@ -243,6 +254,7 @@ def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_

## args
def get_args():
"""parse all arguments for the script"""
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", "-d", default="tkgl-polecat", type=str)
parser.add_argument("--window", "-w", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep
16 changes: 14 additions & 2 deletions examples/linkproppred/tkgl-smallpedia/recurrencybaseline.py
@@ -33,6 +33,11 @@

def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
perf_list_all, hits_list_all, window, neg_sampler, split_mode):
""" create predictions for each relation on test or valid set and compute mrr
:return perf_list_all: list of mrrs for each test query
:return hits_list_all: list of hits for each test query
"""

first_ts = data_c_rel[0][3]
## use this if you want to use ray:
num_queries = len(data_c_rel) // num_processes
@@ -78,7 +83,7 @@ def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,

## test
def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
""" create predictions for each relation on test or valid set and compute mrr
""" create predictions by loopoing through all relations on test or valid set and compute mrr
:return perf_list_all: list of mrrs for each test query
:return hits_list_all: list of hits for each test query
"""
@@ -121,6 +126,10 @@ def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler,
return perf_list_all, hits_list_all

def read_dict_compute_mrr(split_mode='test'):
""" read the results per relation from a precreated file and compute mrr
:return mrr_per_rel: dictionary of mrrs for each relation
:return all_mrrs: list of mrrs for all relations
"""
csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'
# Initialize an empty dictionary to store the data
results_per_rel_dict = {}
@@ -153,7 +162,9 @@ def read_dict_compute_mrr(split_mode='test'):

## train
def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
""" optional, find best values for lambda and alpha
""" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params
based on validation mrr
:return best_config: dictionary of best params for each relation
"""
best_config= {}
best_mrr = 0
@@ -243,6 +254,7 @@ def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_

## args
def get_args():
"""parse all arguments for the script"""
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", "-d", default="tkgl-smallpedia", type=str)
parser.add_argument("--window", "-w", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep
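Finally, the expanded train() docstring describes an optional per-relation grid search: every (lmbda_psi, alpha) pair is evaluated on the validation split and the pair with the best validation MRR is kept per relation. A self-contained sketch under those assumptions, with validate() standing in for the repo's predict():

```python
# Sketch of the per-relation hyperparameter search the train() docstring describes.
import itertools
import numpy as np

def validate(rel, lmbda_psi, alpha):
    # placeholder: returns validation reciprocal ranks for this setting
    rng = np.random.default_rng(hash((rel, lmbda_psi, alpha)) % 2**32)
    return rng.uniform(0, 1, size=10)

def train_sketch(params_dict, rels):
    best_config = {}
    for rel in rels:
        best_mrr = 0.0
        for lmbda_psi, alpha in itertools.product(params_dict['lmbda_psi'],
                                                  params_dict['alpha']):
            mrr = float(np.mean(validate(rel, lmbda_psi, alpha)))
            if mrr > best_mrr:                 # keep the best validation mrr
                best_mrr = mrr
                best_config[rel] = {'lmbda_psi': lmbda_psi, 'alpha': alpha}
    return best_config

print(train_sketch({'lmbda_psi': [0.1, 0.5, 1.0], 'alpha': [0.99, 1.0]},
                   rels=[0, 1]))
```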