Skip to content

Commit

Permalink
Fix integration bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
viswavi committed Oct 13, 2023
1 parent d24a8af commit 6afbaa4
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ def query(self, i, j):
cache_row = None
try:
start = time.perf_counter()
- breakpoint()
response = call_chatgpt(prompt, self.num_predictions, temperature=1.0, max_tokens=1, timeout=2.0)
print(f"response took {round(time.perf_counter()-start, 2)} seconds")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from .copkmeans import COPKMeans
from .pckmeans import PCKMeans
from .cc_pckmeans import CardinalityConstrainedPCKMeans
- from .gptclustering import GPTExpansionClustering, create_single_block_for_prompt
+ from .gptclustering import GPTExpansionClustering
+ from .kmeans_corrector import KMeansCorrection
from .mpckmeans import MPCKMeans
from .mpckmeansmf import MPCKMeansMF
from .mkmeans import MKMeans
from .rcakmeans import RCAKMeans
- from .sccl import SCCL, DeepSCCL
- from .kmeans_corrector import KMeansCorrection
+ from .sccl import SCCL, DeepSCCL
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def __init__(self, X, documents, encoder_model=None, dataset_name=None, prompt=N
self.keep_original_entity = keep_original_entity
self.n_clusters = n_clusters
self.side_information = side_information
- self.cache_dir = "/projects/ogma2/users/vijayv/extra_storage/okb-canonicalization/clustering/file/gpt3_cache"
- cache_file = os.path.join(self.cache_dir, cache_file_name)
+ cache_file = cache_file_name
self.instruction_only = instruction_only
self.demonstration_only = demonstration_only
if instruction_only:
Expand All @@ -61,7 +60,6 @@ def __init__(self, X, documents, encoder_model=None, dataset_name=None, prompt=N
self.read_only = read_only

split_str = f"_{split}" if split else ""
- self.sentence_unprocessing_mapping_file = os.path.join(self.cache_dir, f"{dataset_name}{split_str}_sentence_unprocessing_map.json")

def process_sentence_punctuation(self, sentences):
processed_sentence_set = []
Expand Down Expand Up @@ -106,8 +104,6 @@ def fit(self, X):
template_to_fill = self.construct_gpt3_template(doc_idx, instruction_only=self.instruction_only, demonstration_only=self.demonstration_only)
print(f"PROMPT:\n{template_to_fill}")

- breakpoint()

failure = True
num_retries = 0
while failure and num_retries < self.NUM_RETRIES:
Expand Down Expand Up @@ -169,7 +165,7 @@ def fit(self, X):
if self.prompt_for_encoder is None:
expansion_embeddings = self.encoder_model.encode(all_expansions)
else:
- expansion_embeddings = model.encode([[self.prompt_for_encoder, text] for text in all_expansions])
+ expansion_embeddings = self.encoder_model.encode([[self.prompt_for_encoder, text] for text in all_expansions])

a_vectors = normalize(self.X, axis=1, norm="l2")
b_vectors = normalize(expansion_embeddings, axis=1, norm="l2")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def fit(self, X, y=None, ml=[], cl=[]):
elapsed = time.perf_counter() - start
print(f"elapsed time: {round(elapsed, 3)}")

- if not isinstance(self.side_information, list) and iteration % 10 == 0:
+ if self.side_information is not None and not isinstance(self.side_information, list) and iteration % 10 == 0:
ave_prec, ave_recall, ave_f1, macro_prec, micro_prec, pair_prec, macro_recall, micro_recall, pair_recall, macro_f1, micro_f1, pairwise_f1, model_clusters, model_Singletons, gold_clusters, gold_Singletons = cluster_test(self.side_information.p, self.side_information.side_info, labels, self.side_information.true_ent2clust, self.side_information.true_clust2ent)
metric_dict = {"macro_f1": macro_f1, "micro_f1": micro_f1, "pairwise_f1": pairwise_f1, "ave_f1": ave_f1}
print(f"metric_dict at iteration {iteration}:\t{metric_dict}")
Expand Down

0 comments on commit 6afbaa4

Please sign in to comment.