Skip to content

Commit

Permalink
Fix save_hyperparameters() for top-level CLI (Lightning-AI#1103)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Mar 15, 2024
1 parent b73eb8e commit c80d260
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
15 changes: 15 additions & 0 deletions litgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,21 @@ def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
"""Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
from jsonargparse import capture_parser

# TODO: Make this more robust
# This hack strips away the subcommands from the top-level CLI
# to parse the file as if it was called as a script
known_commands = [
("finetune", "full"),
("finetune", "lora"),
("finetune", "adapter"),
("finetune", "adapter_v2"),
("pretrain",),
]
for known_command in known_commands:
unwanted = slice(1, 1 + len(known_command))
if tuple(sys.argv[unwanted]) == known_command:
sys.argv[unwanted] = []

parser = capture_parser(lambda: CLI(function))
config = parser.parse_args()
parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True)
Expand Down
26 changes: 26 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,32 @@ def test_save_hyperparameters(tmp_path):
assert hparams["bar"] == 1


def _test_function2(out_dir: Path, foo: bool = False, bar: int = 1):
assert False, "I only exist as a signature, but I should not run."


@pytest.mark.parametrize("command", [
"any.py",
"litgpt finetune full",
"litgpt finetune lora",
"litgpt finetune adapter",
"litgpt finetune adapter_v2",
"litgpt pretrain",
])
def test_save_hyperparameters_known_commands(command, tmp_path):
from litgpt.utils import save_hyperparameters

with mock.patch("sys.argv", [*command.split(" "), "--out_dir", str(tmp_path), "--foo", "True"]):
save_hyperparameters(_test_function2, tmp_path)

with open(tmp_path / "hyperparameters.yaml", "r") as file:
hparams = yaml.full_load(file)

assert hparams["out_dir"] == str(tmp_path)
assert hparams["foo"] is True
assert hparams["bar"] == 1


def test_choose_logger(tmp_path):
from litgpt.utils import choose_logger

Expand Down

0 comments on commit c80d260

Please sign in to comment.