Refactor dtype handling in export_llama #9430

Merged: 1 commit merged on Mar 21, 2025

Conversation

jackzhxng (Contributor):

Differential Revision: D71515138

@jackzhxng jackzhxng requested a review from lucylq as a code owner March 19, 2025 23:59

pytorch-bot bot commented Mar 19, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/9430

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit 62f1e9d with merge base a828307 (image):

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 19, 2025
facebook-github-bot (Contributor):

This pull request was exported from Phabricator. Differential Revision: D71515138

jackzhxng added a commit to jackzhxng/executorch that referenced this pull request Mar 20, 2025
Summary: Pull Request resolved: pytorch#9430

Differential Revision: D71515138

@jackzhxng jackzhxng changed the title Refactor dtype handling in export_llama. Refactor dtype handling in export_llama Mar 20, 2025

jackzhxng added a commit to jackzhxng/executorch that referenced this pull request Mar 20, 2025
Summary:
No more converting from fp32 -> checkpoint dtype (fp16 or lower) -> back to the dtype override (fp32), which loses precision on buffers. Also cleans up dtype handling so that dtype conversion happens only outside of model.py, whose responsibility should be limited to loading the model.


Differential Revision: D71515138


Summary:
While it might seem intuitive for the model's dtype to match the checkpoint's dtype, that isn't possible for backends that support only certain dtypes, so the model's dtype must be specified explicitly. This also removes the intermediate conversion into the checkpoint dtype, which could cause precision loss in situations like this:

fp32 -> checkpoint dtype (fp16 or lower) -> back to the dtype override (fp32), where precision is lost on buffers that are instantiated in fp32 and downcast to fp16.


Reviewed By: kimishpatel

Differential Revision: D71515138
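The precision loss described in the summary can be reproduced with a small sketch (plain Python, using the standard library's IEEE-754 half-precision packing; the buffer value is hypothetical and not taken from the PR):

```python
import struct

def roundtrip_fp16(x: float) -> float:
    # Pack into IEEE-754 binary16 ("e" format) and unpack again,
    # simulating a buffer instantiated in fp32, downcast to a fp16
    # checkpoint dtype, then cast back up to the fp32 dtype override.
    return struct.unpack("e", struct.pack("e", x))[0]

original = 3.1415926  # hypothetical fp32 buffer value
converted = roundtrip_fp16(original)
# fp16 keeps only ~11 significand bits, so the value comes back changed
print(original, converted)
assert converted != original
```

Upcasting back to fp32 cannot recover the bits dropped by the fp16 downcast, which is why the refactor avoids routing buffers through the checkpoint dtype.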

@facebook-github-bot facebook-github-bot merged commit 0c1c362 into pytorch:main Mar 21, 2025
78 of 82 checks passed
DannyYuyang-quic pushed a commit to CodeLinaro/executorch that referenced this pull request Apr 2, 2025
Differential Revision: D71515138

Pull Request resolved: pytorch#9430
Labels
CLA Signed, fb-exported, topic: not user facing

3 participants