Skip to content

Commit

Permalink
Reduce global learning rate by 2x
Browse files (browse the repository at this point in the history)
  • Loading branch information
viswavi committed Jun 5, 2023
1 parent e688d82 commit d05d231
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, X, labels, max_queries_cnt=2500, num_predictions=5, side_info
self.num_predictions = num_predictions

self.side_information = side_information
-self.cache_dir = "/projects/ogma1/vijayv/okb-canonicalization/clustering/file/gpt3_cache"
+self.cache_dir = "/home/vijayv/okb-canonicalization/clustering/file/gpt3_cache"
self.cache_file = os.path.join(self.cache_dir, "pairwise_constraint_cache_prompt_engineered_classifier_oracle_free_selector_no_duplicate_pairs.jsonl")
if os.path.exists(self.cache_file):
self.cache_rows = list(jsonlines.open(self.cache_file))
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -40,8 +40,7 @@ def __init__(self,
batch_size = 400,
linear_transformation = False,
canonicalization_side_information=None,
-tensorboard_parent_dir="/projects/ogma1/vijayv/okb-canonicalization/clustering/sccl/",
-tensorboard_dir="tmp"):
+tensorboard_parent_dir="/projects/ogma1/vijayv/okb-canonicalization/clustering/sccl/", tensorboard_dir="tmp"):
self.n_clusters = n_clusters
self.max_iter = max_iter
self.normalize_vectors = normalize_vectors
Expand Down
Expand Up @@ -82,8 +81,6 @@ def fit(self, X, pairwise_constraints):

train_loader = util_data.DataLoader(X.astype("float32"), batch_size=self.batch_size, shuffle=True, num_workers=1)

-breakpoint()
-
model = SCCLMatrix(emb_size=X.shape[1], cluster_centers=cluster_centers, include_contrastive_loss=self.include_contrastive_loss, linear_transformation = self.linear_transformation)
model = model.cuda()

Expand All @@ -95,7 +92,7 @@ def fit(self, X, pairwise_constraints):

# optimize
Args = namedtuple("Args", "lr lr_scale lr_scale_scl eta temperature objective print_freq max_iter batch_size tensorboard")
-args = Args(1e-05, 100, 100, 10, 0.5, "SCCL", 300, 2500 * X.shape[0] / self.batch_size, self.batch_size, tensorboard)
+args = Args(5e-06, 100, 100, 10, 0.5, "SCCL", 300, 2000 * X.shape[0] / self.batch_size, self.batch_size, tensorboard)


optimizer = get_optimizer_linear_transformation(model, args, include_contrastive_loss=self.include_contrastive_loss, linear_transformation=self.linear_transformation)
Expand Down

0 comments on commit d05d231

Please sign in to comment.