Commit
[LTC] Add custom lazy tensor save function (pytorch#83294)
We need a custom `save` function for checkpointing a lazy model, similar to what exists in PyTorch/XLA:
https://github.com/pytorch/xla/blob/3eb8a9d9eb4ebb0b064461c3704650241625654e/torch_xla/core/xla_model.py#L994

The purpose of this function is to move any lazy tensors to CPU before saving the checkpoint.

The way I implemented it was to create a general structure visitor, adapted from a function that we use quite often in Cerebras internal repositories. If there is a better tool already available in PyTorch that does the same things, I'm open to suggestions.

CC: @wconstab @Krovatkin @JackCaoG

Pull Request resolved: pytorch#83294
Approved by: https://github.com/wconstab
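For illustration only, the idea described above (visit a nested checkpoint structure, move tensor leaves to CPU, then serialize) might be sketched as follows. This is not the code from this commit. The names `map_structure`, `to_cpu`, `save`, and `FakeLazyTensor` are hypothetical, and the sketch assumes duck typing on a `.cpu()` method and uses `pickle` in place of `torch.save` so that it runs without PyTorch installed.

```python
import io
import pickle


def map_structure(fn, obj):
    """Rebuild `obj`, applying `fn` to every leaf that is not a container."""
    if isinstance(obj, dict):
        # Assumes the dict type accepts an iterable of pairs (dict, OrderedDict).
        return type(obj)((k, map_structure(fn, v)) for k, v in obj.items())
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # namedtuple: its constructor takes positional fields, not an iterable.
        return type(obj)(*(map_structure(fn, v) for v in obj))
    if isinstance(obj, (list, tuple)):
        return type(obj)(map_structure(fn, v) for v in obj)
    return fn(obj)


def to_cpu(leaf):
    """Call `.cpu()` on tensor-like leaves; return anything else unchanged."""
    cpu = getattr(leaf, "cpu", None)
    return cpu() if callable(cpu) else leaf


def save(obj, f):
    """Move tensor-like leaves to CPU, then serialize to the file object `f`."""
    # A real implementation would presumably call torch.save here (assumption).
    pickle.dump(map_structure(to_cpu, obj), f)


class FakeLazyTensor:
    """Hypothetical stand-in for a lazy tensor, used only in this sketch."""

    def __init__(self, device):
        self.device = device

    def cpu(self):
        # Returns a new object, leaving the original on its device.
        return FakeLazyTensor("cpu")


if __name__ == "__main__":
    checkpoint = {"weight": FakeLazyTensor("lazy"), "step": (1, 2)}
    moved = map_structure(to_cpu, checkpoint)
    print(moved["weight"].device, checkpoint["weight"].device)
```

The visitor returns a rebuilt copy rather than mutating its input, so the original structure keeps its lazy tensors and training can continue after the checkpoint is written.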