Add method to update internal hyperparameters for FBGEMM TBE (pytorch#3025)
Summary:
X-link: facebookresearch/FBGEMM#120

Pull Request resolved: pytorch#3025

D61634495 would allow a hyperparameter schedule to change hyperparameter values during training (analogous to a learning rate schedule). To support this, we define a method `update_hyper_parameters` that updates the internal parameter values.
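
As an illustration (not part of this commit), a hyperparameter schedule could drive the new API as in the sketch below; `tbe` is assumed to be an already-constructed `SplitTableBatchedEmbeddingBagsCodegen`, and `hyperparameter_schedule` is a hypothetical helper:

from typing import Dict

def hyperparameter_schedule(step: int) -> Dict[str, float]:
    # Hypothetical schedule: exponentially decay the learning rate, keep eps fixed.
    return {"lr": 0.1 * (0.95 ** step), "eps": 1.0e-8}

# tbe: SplitTableBatchedEmbeddingBagsCodegen, assumed constructed elsewhere.
for step in range(100):
    # Push the scheduled values into the TBE module; supported keys are
    # "lr", "eps", "beta1", "beta2", "weight_decay", and "lower_bound".
    tbe.update_hyper_parameters(hyperparameter_schedule(step))
    # ... forward / backward; the TBE's fused optimizer uses the updated values ...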

Reviewed By: minddrummer, spcyppt

Differential Revision: D61642016

fbshipit-source-id: 663182d53681f8fffe97a2bcab26fc8ba8d9b89c
Wang Zhou authored and facebook-github-bot committed Sep 4, 2024
1 parent 55721b3 commit 95a5a76
Showing 2 changed files with 81 additions and 0 deletions.
@@ -2209,6 +2209,33 @@ def set_learning_rate(self, lr: float) -> None:
            )
        self._set_learning_rate(lr)

    @torch.jit.ignore
    def update_hyper_parameters(self, params_dict: Dict[str, float]) -> None:
        """
        Sets hyper-parameters from external control flow.

        Supported keys: "lr", "eps", "beta1", "beta2", "weight_decay",
        and "lower_bound".
        """
        if self.optimizer == OptimType.NONE:
            raise NotImplementedError(
                f"Updating hyper-parameters is not supported for {self.optimizer}"
            )
        for parameter_name, value in params_dict.items():
            if parameter_name == "lr":
                self._set_learning_rate(value)
            elif parameter_name == "eps":
                self.optimizer_args = self.optimizer_args._replace(eps=value)
            elif parameter_name == "beta1":
                self.optimizer_args = self.optimizer_args._replace(beta1=value)
            elif parameter_name == "beta2":
                self.optimizer_args = self.optimizer_args._replace(beta2=value)
            elif parameter_name == "weight_decay":
                self.optimizer_args = self.optimizer_args._replace(weight_decay=value)
            elif parameter_name == "lower_bound":
                self.gwd_lower_bound = value
            else:
                raise NotImplementedError(
                    f"Setting hyper-parameter {parameter_name} is not supported"
                )

    @torch.jit.ignore
    def _set_learning_rate(self, lr: float) -> float:
        """
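
For context, the non-`lr` branches above rely on `NamedTuple._replace`, which returns a copy with the given field updated rather than mutating in place. A minimal standalone sketch of that pattern (the field set here is illustrative; the real `optimizer_args` carries more fields):

from typing import NamedTuple

class OptimizerArgs(NamedTuple):  # illustrative stand-in, not the FBGEMM type
    eps: float
    beta1: float
    beta2: float
    weight_decay: float

args = OptimizerArgs(eps=1.0e-8, beta1=0.9, beta2=0.999, weight_decay=0.0)
args = args._replace(eps=1.0e-7)  # new tuple with eps updated; other fields kept
assert args.eps == 1.0e-7 and args.beta1 == 0.9
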
54 changes: 54 additions & 0 deletions fbgemm_gpu/test/tbe/utils/split_embeddings_test.py
@@ -559,6 +559,60 @@ def check_weight_momentum(v: int) -> None:

        check_weight_momentum(0)

    @unittest.skipIf(*gpu_unavailable)
    @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
    def test_update_hyper_parameters(self) -> None:
        # Create an abstract split table
        D = 8
        T = 2
        E = 10**2
        Ds = [D] * T
        Es = [E] * T

        hyperparameters = {
            "eps": 0.1,
            "beta1": 0.9,
            "beta2": 0.999,
            "weight_decay": 0.0,
        }
        cc = SplitTableBatchedEmbeddingBagsCodegen(
            embedding_specs=[
                (
                    E,
                    D,
                    EmbeddingLocation.DEVICE,
                    ComputeDevice.CUDA,
                )
                for (E, D) in zip(Es, Ds)
            ],
            learning_rate=0.1,
            **hyperparameters,  # pyre-ignore[6]
        )

        # Update hyperparameters: bump each existing value by 1.0 and merge in
        # "lr" and "lower_bound" via the dict union operator `|` (Python 3.9+)
        updated_parameters = {
            key: value + 1.0 for key, value in hyperparameters.items()
        } | {"lr": 1.0, "lower_bound": 2.0}
        cc.update_hyper_parameters(updated_parameters)
        self.assertAlmostEqual(
            cc.optimizer_args.learning_rate, updated_parameters["lr"]
        )
        self.assertAlmostEqual(cc.optimizer_args.eps, updated_parameters["eps"])
        self.assertAlmostEqual(cc.optimizer_args.beta1, updated_parameters["beta1"])
        self.assertAlmostEqual(cc.optimizer_args.beta2, updated_parameters["beta2"])
        self.assertAlmostEqual(
            cc.optimizer_args.weight_decay, updated_parameters["weight_decay"]
        )
        self.assertAlmostEqual(cc.gwd_lower_bound, updated_parameters["lower_bound"])

        # Update hyperparameters with invalid parameter name
        invalid_parameter = "invalid_parameter"
        with self.assertRaisesRegex(
            NotImplementedError,
            f"Setting hyper-parameter {invalid_parameter} is not supported",
        ):
            cc.update_hyper_parameters({"invalid_parameter": 1.0})


if __name__ == "__main__":
    unittest.main()
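
For reference, one hypothetical way to run only the new test locally (assuming a CUDA-capable machine with an FBGEMM GPU build installed; `-k` is standard unittest test selection, not something added by this commit):

python fbgemm_gpu/test/tbe/utils/split_embeddings_test.py -k test_update_hyper_parameters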
