
Update _maybe_compute_kjt_to_jt_dict to support JIT tracing #750

Open · wants to merge 1 commit into base: main
Conversation

zxpmirror1994

Summary:
torch.split takes a list of int (length_per_key). Under JIT tracing, that list of int is captured as a fixed value, so the traced code cannot generalize to inputs with different per-key lengths.
#417
Use torch.tensor_split, which accepts a tensor, instead of torch.split to avoid this issue (a sketch of the difference follows below).

Differential Revision: D40662254
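A minimal sketch of the difference under torch.jit.trace, using standalone tensors with illustrative names rather than the actual _maybe_compute_kjt_to_jt_dict internals:

```python
import torch

def split_per_key(values: torch.Tensor, lengths: torch.Tensor):
    # lengths is a 1-D tensor of per-key lengths, e.g. tensor([3, 4, 3]).
    # Passing python ints to torch.split would record the literal sizes in a
    # JIT trace; deriving the split points as a tensor and calling
    # torch.tensor_split keeps them as tensor inputs to the traced op.
    split_indices = torch.cumsum(lengths, dim=0)[:-1]
    return torch.tensor_split(values, split_indices)

values = torch.arange(10.0)
lengths = torch.tensor([3, 4, 3])
chunks = split_per_key(values, lengths)
assert [c.numel() for c in chunks] == [3, 4, 3]

# The traced function re-reads the lengths tensor at run time, so inputs with
# different per-key lengths (but the same number of keys) still split correctly.
traced = torch.jit.trace(split_per_key, (values, lengths))
```

With torch.split(values, [3, 4, 3]) the trace would instead hard-code the sizes 3, 4, 3 and fail to generalize, which is the problem this PR addresses.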

facebook-github-bot added the CLA Signed and fb-exported labels on Oct 25, 2022

facebook-github-bot (Contributor) commented:

This pull request was exported from Phabricator. Differential Revision: D40662254

zxpmirror1994 pushed a commit to zxpmirror1994/torchrec that referenced this pull request Oct 25, 2022
Update _maybe_compute_kjt_to_jt_dict to support JIT tracing (pytorch#750)

Summary:
Pull Request resolved: pytorch#750

torch.split takes a list of int (length_per_key). Under JIT tracing, that list of int is captured as a fixed value, so the traced code cannot generalize to inputs with different per-key lengths.
pytorch#417
Use torch.tensor_split, which accepts a tensor, instead of torch.split to avoid this issue. Also create a method to generate a tensor-typed length_per_key for tensor_split to consume (a sketch of such a helper follows below).

Differential Revision: D40662254

fbshipit-source-id: bbe920abf0e67383f1758cb5197867ff15132a47
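A hedged sketch of what a tensor-typed length_per_key helper could look like; the helper name and the key-major lengths layout are assumptions for illustration, not the method actually added in this diff:

```python
import torch

def _tensor_length_per_key(lengths: torch.Tensor, num_keys: int) -> torch.Tensor:
    # Assumes the key-major layout KeyedJaggedTensor uses for lengths:
    # [key_0 per-sample lengths..., key_1 per-sample lengths..., ...].
    # Summing per key yields length_per_key as a tensor rather than List[int],
    # so it can feed torch.tensor_split without leaving the tensor world.
    return lengths.view(num_keys, -1).sum(dim=1)

lengths = torch.tensor([1, 2, 0, 3, 1, 2])            # 2 keys, batch size 3
length_per_key = _tensor_length_per_key(lengths, 2)   # tensor([3, 6])
values = torch.arange(9.0)
per_key_values = torch.tensor_split(values, torch.cumsum(length_per_key, dim=0)[:-1])
assert [v.numel() for v in per_key_values] == [3, 6]
```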
facebook-github-bot (Contributor) commented:

This pull request was exported from Phabricator. Differential Revision: D40662254

zxpmirror1994 pushed a commit to zxpmirror1994/torchrec that referenced this pull request:
Update _maybe_compute_kjt_to_jt_dict to support JIT tracing (pytorch#750)

Summary:
Pull Request resolved: pytorch#750

torch.split takes a list of int (length_per_key). Under JIT tracing, that list of int is captured as a fixed value, so the traced code cannot generalize to inputs with different per-key lengths.
pytorch#417
Use torch.tensor_split, which accepts a tensor, instead of torch.split to avoid this issue. Also create a method to generate a tensor-typed length_per_key for tensor_split to consume.

Differential Revision: D40662254

fbshipit-source-id: 59639160876d64b9b7327673c21b318d57c582ab
facebook-github-bot (Contributor) commented:

This pull request was exported from Phabricator. Differential Revision: D40662254

Labels: CLA Signed, fb-exported
Projects: None yet
2 participants