
Pass context_fn to activation_checkpoint_params #951

Closed
y-sq wants to merge 1 commit

Conversation

@y-sq (Contributor) commented on Dec 3, 2024

Summary:
We use prepare_module from torchtnt in model training to wrap the model with activation checkpointing. We want to pass a customized context_fn to the checkpointing wrapper (to skip recomputing some ops), so this diff adds a context_fn field to activation_checkpoint_params.
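
For illustration, a minimal sketch of how the new field would be used. Everything here besides the context_fn field itself (import paths, the other ActivationCheckpointParams fields, and the prepare_module call shape) is an assumption about torchtnt's API surface, not code from this diff:

```python
# Hedged sketch: wiring a custom context_fn through activation_checkpoint_params.
import contextlib

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointImpl
# Assumed import path; check the installed torchtnt version.
from torchtnt.utils.prepare_module import ActivationCheckpointParams, prepare_module


def my_context_fn():
    # Contract from torch.utils.checkpoint.checkpoint(context_fn=...): return
    # two context managers, the first entered during the original forward pass
    # and the second during recomputation in the backward pass. nullcontext
    # keeps the sketch minimal; a real context_fn would return contexts that
    # cache selected ops (see the float8 sketch in the commit below).
    return contextlib.nullcontext(), contextlib.nullcontext()


# context_fn is only honored by the non-reentrant checkpoint implementation.
ac_params = ActivationCheckpointParams(
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    check_fn=lambda m: isinstance(m, torch.nn.TransformerEncoderLayer),
    context_fn=my_context_fn,
)

# prepare_module then threads the params into the checkpoint wrapper
# (hypothetical call shape):
# module = prepare_module(module, device, activation_checkpoint_params=ac_params)
```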

The test plan (running the training job end-to-end) is included in the original diff.

Reviewed By: yoyoyocmu

Differential Revision: D65360604

@facebook-github-bot commented: This pull request was exported from Phabricator. Differential Revision: D65360604

@y-sq changed the title from "Don't recompute scaling factor in activation checkpointing" to "Pass context_fn to activation_checkpoint_params" on Dec 3, 2024
y-sq added a commit to y-sq/tnt that referenced this pull request Dec 3, 2024
Summary:

Add the policy "layer_based_auto_wrap_policy_float8_training". It skips recomputing the float8 scaling factor (a scalar) to improve latency.

To enable it, update the config file as in P1690229394.
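
For intuition, the scale-skipping behavior can be expressed with PyTorch's selective activation checkpointing helpers (available in recent PyTorch releases). This is a hedged sketch rather than the code from this diff: the op set assumes the float8 scaling factor is produced by an amax (max-reduction) op, and the actual policy in D65360604 may differ.

```python
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

# Assumption: the float8 scale comes from an amax reduction; saving that
# scalar result means the backward pass never recomputes it.
_scale_ops = {torch.ops.aten.max.default}


def _policy_fn(ctx, op, *args, **kwargs):
    # Save the scale-producing op's output; recompute everything else as
    # ordinary activation checkpointing would.
    if op in _scale_ops:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def float8_training_context_fn():
    # Returns the (forward, recompute) context-manager pair that
    # torch.utils.checkpoint.checkpoint consumes via context_fn.
    return create_selective_checkpoint_contexts(_policy_fn)
```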

Reviewed By: yoyoyocmu

Differential Revision: D65360604
@y-sq force-pushed the export-D65360604 branch from 8d28ecc to 48aa7d7 on December 3, 2024 01:17
@facebook-github-bot commented: This pull request was exported from Phabricator. Differential Revision: D65360604

@y-sq force-pushed the export-D65360604 branch from 48aa7d7 to 7c52c99 on December 3, 2024 01:18
@facebook-github-bot commented: This pull request was exported from Phabricator. Differential Revision: D65360604
