[LTC] Add custom lazy tensor save function (pytorch#83294)
We need a custom `save` function for checkpointing a lazy model, similar to what exists in PyTorch/XLA:
https://github.com/pytorch/xla/blob/3eb8a9d9eb4ebb0b064461c3704650241625654e/torch_xla/core/xla_model.py#L994
The purpose of this function is to move any lazy tensors to CPU before saving the checkpoint.

I implemented it by creating a general structure visitor, adapted from a function we use quite often in Cerebras' internal repositories. If there is a better tool already available in PyTorch that does the same thing, I'm open to suggestions.
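For context, the "structure visitor" used here is PyTorch's pytree utility: `tree_flatten` decomposes an arbitrarily nested container into its leaf values plus a spec recording the nesting, and `tree_unflatten` rebuilds the original structure from transformed leaves. A minimal standalone sketch of that round trip (the `state` contents are illustrative, not from this commit):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# An arbitrarily nested, checkpoint-like structure (illustrative).
state = {"model": [torch.ones(2), torch.zeros(3)], "step": torch.tensor(7)}

# tree_flatten walks dicts/lists/tuples and returns the leaf values
# plus a spec that records the original nesting.
leaves, spec = tree_flatten(state)

# Transform every leaf, then rebuild the original structure.
moved = [t.to("cpu") for t in leaves]
restored = tree_unflatten(moved, spec)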

CC: @wconstab @Krovatkin @JackCaoG
Pull Request resolved: pytorch#83294
Approved by: https://github.com/wconstab
antoniojkim authored and pytorchmergebot committed Aug 24, 2022
1 parent 3e6e0a1 commit 4eb02e8
Showing 1 changed file with 13 additions and 0 deletions.
torch/_lazy/__init__.py
@@ -1,4 +1,5 @@
 import torch._C._lazy
+from torch.utils._pytree import tree_flatten, tree_unflatten
 
 
 def mark_step(device: str = "", wait=False):
@@ -34,3 +35,15 @@ def sync_multi(tensors, devices):
 def get_tensor_id(tensor):
     """Return a unique id of the lazy tensor maintained by LTC"""
     return torch._C._lazy._get_tensor_id(tensor)
+
+
+def to_cpu(tensors, devices=None):
+    devices = devices or ["lazy"]
+
+    flattened, spec = tree_flatten(tensors)
+    sync_multi(flattened, devices)
+    return tree_unflatten([t.to("cpu") for t in flattened], spec)
+
+
+def save(tensors, *args, **kwargs):
+    torch.save(to_cpu(tensors), *args, **kwargs)
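
A minimal usage sketch of the new `save` function, assuming the TorchScript-based lazy backend is available via `torch._lazy.ts_backend.init()`; the model and file name are illustrative:

import torch
import torch._lazy
import torch._lazy.ts_backend

# Assumption: the TorchScript lazy backend is the one in use.
torch._lazy.ts_backend.init()

model = torch.nn.Linear(4, 2).to(device="lazy")
out = model(torch.randn(8, 4, device="lazy"))

# save() syncs the lazy tensors, moves them to CPU via to_cpu(),
# then delegates to torch.save.
checkpoint = {"weight": model.weight, "bias": model.bias, "out": out}
torch._lazy.save(checkpoint, "checkpoint.pt")

Note that `to_cpu` calls `.to("cpu")` on every flattened leaf, so the structure passed to `save` is expected to contain only tensors.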
