Send around optim state in pickled form
Summary: The use of DatasetIO to avoid copies was neat, but it doesn't extend well to a more generic interface. Since optimizers' state dicts are roughly 1% the size of their corresponding embedding tables, removing DatasetIO shouldn't introduce any meaningful performance regression.

Reviewed By: adamlerer

Differential Revision: D16961527

fbshipit-source-id: 1f48eb3c735a0d435ac61f39b3c13828aaa0353c
lw authored and facebook-github-bot committed Aug 29, 2019
1 parent cf62ab1 commit 6655a46
Showing 4 changed files with 71 additions and 152 deletions.
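For context, "pickled form" here means the optimizer state dict is run through torch.save into an in-memory buffer and shipped as raw bytes. A minimal sketch of that round trip, mirroring the serialize_optim_state / deserialize_optim_state helpers added below (the toy parameter and optimizer are illustrative, not PBG code):

```python
import io

import torch


def serialize_optim_state(optim_state):
    # Pickle the state dict into an in-memory buffer and return raw bytes.
    if optim_state is None:
        return None
    with io.BytesIO() as bf:
        torch.save(optim_state, bf)
        return bf.getvalue()


def deserialize_optim_state(serialized_optim_state):
    # Inverse of serialize_optim_state; note this is pickle under the hood,
    # so the bytes must come from a trusted source.
    if serialized_optim_state is None:
        return None
    with io.BytesIO(serialized_optim_state) as bf:
        return torch.load(bf)


# Round-trip a real optimizer's state through bytes.
param = torch.nn.Parameter(torch.randn(10, 4))
optimizer = torch.optim.Adagrad([param])
param.sum().backward()
optimizer.step()  # populate the optimizer's internal state

blob = serialize_optim_state(optimizer.state_dict())
restored = deserialize_optim_state(blob)
assert restored.keys() == optimizer.state_dict().keys()
```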
65 changes: 1 addition & 64 deletions test/test_checkpoint_storage.py
@@ -6,72 +6,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE.txt file in the root directory of this source tree.

-import io
-import tempfile
-from typing import Any
from unittest import TestCase, main

-import h5py
-import numpy as np
-import torch
-
-from torchbiggraph.checkpoint_storage import DatasetIO, TwoWayMapping
-
-
-class TestDatasetIO(TestCase):
-
-    # DatasetIO is only used wrapped in a BufferedReader as a source for
-    # torch.load, hence we test it only in this setting.
-
-    @staticmethod
-    def save_to(hf: h5py.File, name: str, data: Any) -> None:
-        with io.BytesIO() as bf:
-            torch.save(data, bf)
-            hf.create_dataset(
-                name, data=np.frombuffer(bf.getbuffer(), dtype=np.dtype("V1")))
-
-    @staticmethod
-    def load_from(hf: h5py.File, name: str) -> Any:
-        with io.BufferedReader(DatasetIO(hf[name])) as bf:
-            return torch.load(bf)
-
-    def test_scalars(self):
-        data = (["a", b"b"], {1: True, 0.2: {None, 4j}})
-        # FIXME h5py-2.9 accepts just File(bf), allowing an un-Named TemporaryFile.
-        with tempfile.NamedTemporaryFile() as bf:
-            with h5py.File(bf.name, "w") as hf:
-                self.save_to(hf, "foo", data)
-            with h5py.File(bf.name, "r") as hf:
-                self.assertEqual(self.load_from(hf, "foo"), data)
-
-    def test_tensors(self):
-        data_foo = torch.zeros((100,), dtype=torch.int8)
-        data_bar = torch.ones((10, 10))
-        # FIXME h5py-2.9 accepts just File(bf), allowing an un-Named TemporaryFile.
-        with tempfile.NamedTemporaryFile() as bf:
-            with h5py.File(bf.name, "w") as hf:
-                self.save_to(hf, "foo", data_foo)
-                self.save_to(hf, "bar", data_bar)
-            with h5py.File(bf.name, "r") as hf:
-                self.assertTrue(data_foo.equal(self.load_from(hf, "foo")))
-                self.assertTrue(data_bar.equal(self.load_from(hf, "bar")))
-
-    def test_bad_args(self):
-        # FIXME h5py-2.9 accepts just File(bf), allowing an un-Named TemporaryFile.
-        with tempfile.NamedTemporaryFile() as bf:
-            with h5py.File(bf.name, "w") as hf:
-                # Scalar array of "V<length>" type as suggested in the h5py doc.
-                data = np.void(b"data")
-                with self.assertRaises(TypeError):
-                    DatasetIO(hf.create_dataset("foo", data=data))
-                # One-dimensional array of uint8 type.
-                data = np.frombuffer(b"data", dtype=np.uint8)
-                with self.assertRaises(TypeError):
-                    DatasetIO(hf.create_dataset("bar", data=data))
-                # Two-dimensional array of bytes.
-                data = np.frombuffer(b"data", dtype=np.dtype("V1")).reshape(2, 2)
-                with self.assertRaises(TypeError):
-                    DatasetIO(hf.create_dataset("baz", data=data))
+from torchbiggraph.checkpoint_storage import TwoWayMapping


class TestTwoWayMapping(TestCase):
78 changes: 61 additions & 17 deletions torchbiggraph/checkpoint_manager.py
@@ -6,6 +6,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE.txt file in the root directory of this source tree.

+import io
import json
import logging
import multiprocessing as mp
@@ -16,9 +17,9 @@
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

+import numpy as np
import torch
import torch.multiprocessing
-from torch_extensions.rpc.rpc import _deserialize as torch_rpc_deserialize
-from torch_extensions.rpc.rpc import _serialize as torch_rpc_serialize

from torchbiggraph.checkpoint_storage import (
    load_entity_partition,
@@ -30,6 +31,7 @@
from torchbiggraph.config import ConfigSchema
from torchbiggraph.parameter_sharing import ParameterClient
from torchbiggraph.types import (
+    ByteTensorType,
    EntityName,
    FloatTensorType,
    ModuleStateDict,
@@ -57,6 +59,14 @@ def noop() -> None:
    pass


+def bytes_to_bytetensor(data: bytes) -> ByteTensorType:
+    return torch.from_numpy(np.frombuffer(data, dtype=np.uint8))
+
+
+def bytetensor_to_bytes(tensor: ByteTensorType) -> bytes:
+    return tensor.numpy().tobytes()


VERSION_FILE = "checkpoint_version.txt"
CONFIG_FILE = "config.json"

@@ -74,23 +84,29 @@ def store(
        entity: EntityName,
        part: Partition,
        embs: FloatTensorType,
-        optim_state: Optional[OptimizerStateDict],
+        optim_state: Optional[bytes],
    ) -> None:
        client = self._clients[part % len(self._clients)]
        key = "%s_%s" % (entity, part)
        client.store(key + "__embs", embs)
-        client.store(key + "__optim", torch_rpc_serialize(optim_state))
+        if optim_state is not None:
+            optim_state_tensor = bytes_to_bytetensor(optim_state)
+            client.store(key + "__optim", optim_state_tensor)

    def get(
        self,
        entity: EntityName,
        part: Partition,
-    ) -> Tuple[FloatTensorType, OptimizerStateDict]:
+    ) -> Tuple[FloatTensorType, Optional[bytes]]:
        client = self._clients[part % len(self._clients)]
        key = "%s_%s" % (entity, part)
        embs = client.get(key + "__embs", shared=True)
        assert embs is not None
-        optim_state = torch_rpc_deserialize(client.get(key + "__optim"))
+        optim_state_tensor = client.get(key + "__optim")
+        if optim_state_tensor is not None:
+            optim_state = bytetensor_to_bytes(optim_state_tensor)
+        else:
+            optim_state = None
        return embs, optim_state

    def join(self) -> None:
Expand All @@ -114,6 +130,25 @@ def get_checkpoint_metadata(self) -> Dict[str, Any]:
return {"config/json": self.json_config_dict}


+def serialize_optim_state(
+    optim_state: Optional[OptimizerStateDict],
+) -> Optional[bytes]:
+    if optim_state is None:
+        return None
+    with io.BytesIO() as bf:
+        torch.save(optim_state, bf)
+        return bf.getvalue()
+
+
+def deserialize_optim_state(
+    serialized_optim_state: Optional[bytes],
+) -> Optional[OptimizerStateDict]:
+    if serialized_optim_state is None:
+        return None
+    with io.BytesIO(serialized_optim_state) as bf:
+        return torch.load(bf)


class CheckpointManager:
"""Reads and writes checkpoint data to/from disk.
@@ -266,17 +301,18 @@ def write(
        self._sync(file_path)

        metadata = self.collect_metadata()
+        serialized_optim_state = serialize_optim_state(optim_state)

        if self.partition_client is not None:
-            self.partition_client.store(entity, part, embs, optim_state)
+            self.partition_client.store(entity, part, embs, serialized_optim_state)
        elif self.background:
            if file_path in self.prefetched:
                self.prefetched.pop(file_path)
            future_res = self.pool.apply_async(
-                save_entity_partition, (file_path, embs, optim_state, metadata))
+                save_entity_partition, (file_path, embs, serialized_optim_state, metadata))
            self.outstanding[file_path] = future_res
        else:
-            save_entity_partition(file_path, embs, optim_state, metadata)
+            save_entity_partition(file_path, embs, serialized_optim_state, metadata)

    def read(
        self,
@@ -292,12 +328,17 @@ def read(

        file_path = self._file_path(entity, part)
        if (entity, part) in self.dirty and self.partition_client is not None:
-            return self.partition_client.get(entity, part)
-        if self.background:
+            embs, serialized_optim_state = self.partition_client.get(entity, part)
+        elif self.background:
            self._sync(file_path)
            if file_path in self.prefetched:
-                return self.prefetched.pop(file_path)
-        return load_entity_partition(file_path)
+                embs, serialized_optim_state = self.prefetched.pop(file_path)
+            else:
+                embs, serialized_optim_state = load_entity_partition(file_path)
+        else:
+            embs, serialized_optim_state = load_entity_partition(file_path)
+        optim_state = deserialize_optim_state(serialized_optim_state)
+        return embs, optim_state

    def maybe_read(
        self,
@@ -338,14 +379,17 @@ def write_model(
        version = self._version(True)
        file_path = os.path.join(self.path, f"model.v{version}.h5")
        metadata = self.collect_metadata()
-        save_model(file_path, model_state, optim_state, metadata, MODEL_STATE_DICT_MAPPINGS)
+        serialized_optim_state = serialize_optim_state(optim_state)
+        save_model(file_path, model_state, serialized_optim_state, metadata, MODEL_STATE_DICT_MAPPINGS)

    def read_model(
        self,
    ) -> Tuple[Optional[ModuleStateDict], Optional[OptimizerStateDict]]:
        version = self._version(False)
        file_path = os.path.join(self.path, f"model.v{version}.h5")
-        return load_model(file_path, MODEL_STATE_DICT_MAPPINGS)
+        state_dict, serialized_optim_state = load_model(file_path, MODEL_STATE_DICT_MAPPINGS)
+        optim_state = deserialize_optim_state(serialized_optim_state)
+        return state_dict, optim_state

    def maybe_read_model(
        self,
@@ -378,13 +422,13 @@ def write_new_version(self, config: ConfigSchema) -> None:
            for entity, econf in config.entities.items():
                for part in range(self.rank, econf.num_partitions, self.num_machines):
                    logger.debug(f"Getting {entity} {part}")
-                    embs, optim_state = \
+                    embs, serialized_optim_state = \
                        self.partition_client.get(EntityName(entity), Partition(part))
                    logger.debug(f"Done getting {entity} {part}")
                    new_file_path = os.path.join(
                        self.path, f"embeddings_{entity}_{part}.v{new_version}.h5")
                    logger.debug(f"Saving {entity} {part} to {new_file_path}")
-                    save_entity_partition(new_file_path, embs, optim_state, metadata)
+                    save_entity_partition(new_file_path, embs, serialized_optim_state, metadata)
                    logger.debug(f"Done saving {entity} {part} to {new_file_path}")

    def switch_to_new_version(self) -> None:
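The parameter-server path moves tensors, not raw bytes, which is why the diff views the pickled blob as a flat uint8 tensor before handing it to the client. A self-contained sketch of just that conversion (helper bodies copied from the diff; the sample blob is illustrative):

```python
import numpy as np
import torch


def bytes_to_bytetensor(data: bytes) -> torch.Tensor:
    # View the pickled blob as a flat uint8 tensor so it can travel over the
    # same channel as embeddings. np.frombuffer is zero-copy; PyTorch may warn
    # that the resulting tensor is non-writable, which is fine here.
    return torch.from_numpy(np.frombuffer(data, dtype=np.uint8))


def bytetensor_to_bytes(tensor: torch.Tensor) -> bytes:
    # Inverse conversion on the receiving side.
    return tensor.numpy().tobytes()


blob = b"pretend this is torch.save output"
wire_tensor = bytes_to_bytetensor(blob)
assert wire_tensor.dtype == torch.uint8 and wire_tensor.numel() == len(blob)
assert bytetensor_to_bytes(wire_tensor) == blob
```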
73 changes: 5 additions & 68 deletions torchbiggraph/checkpoint_storage.py
@@ -29,66 +29,6 @@
NP_VOID_DTYPE = np.dtype("V1")


-class DatasetIO(io.RawIOBase):
-    """A file-like proxy to a HDF5 dataset
-
-    Given a one-dimensional HFD5 dataset object whose elements are bytes, this
-    class wraps it and provides access to it through a file-like interface. The
-    "file" is open in binary mode (i.e. returns bytes objects rather than strs),
-    is read-only (writing could be easily supported, but isn't needed), seekable
-    and only offers "raw" (unbuffered) I/O. Users will probably want to wrap it
-    in a BufferedReader for better performance.
-
-    This is needed as a compatibility layer to enable features that only support
-    file-like objects (like torch.load) to read from HDF5-backed storage and
-    only load data as-needed (rather than pre-loading everything, as would be
-    necessary with BytesIO).
-
-    Writing isn't supported because (non-chunked) HDF5 datasets must be created
-    with their final size known in advance, which is usually not possible with
-    a file-like interface.
-
-    """
-
-    def __init__(self, dataset: h5py.Dataset):
-        if dataset.dtype != NP_VOID_DTYPE:
-            raise TypeError("Dataset doesn't contain bytes")
-        if dataset.shape != (dataset.size,):
-            raise TypeError("Dataset isn't a one-dimensional array")
-        self.dataset = dataset
-        self.pos = 0
-
-    def readable(self) -> bool:
-        return True
-
-    def readinto(self, buffer: bytearray) -> int:
-        array = np.frombuffer(buffer, dtype=NP_VOID_DTYPE)
-        size = min(len(buffer), self.dataset.size - self.pos)
-        # Needed because https://github.com/h5py/h5py/issues/870.
-        if size > 0:
-            self.dataset.read_direct(array, np.s_[self.pos:self.pos + size], np.s_[:size])
-        self.pos += size
-        return size
-
-    def readall(self) -> bytes:
-        # We're supposed to implement this, but it doesn't appear to be needed.
-        raise io.UnsupportedOperation()
-
-    def seekable(self) -> bool:
-        return True
-
-    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
-        if whence is io.SEEK_SET:
-            self.pos = offset
-        if whence is io.SEEK_CUR:
-            self.pos += offset
-        if whence is io.SEEK_END:
-            self.pos = self.dataset.size + offset
-        return self.pos
-
-    def tell(self) -> int:
-        return self.pos

# Names and values of metadata attributes for the HDF5 files.
FORMAT_VERSION_ATTR = "format_version"
FORMAT_VERSION = 1
@@ -115,21 +55,18 @@ def load_embeddings(hf: h5py.File) -> FloatTensorType:

def save_optimizer_state_dict(
    hf: h5py.File,
-    state_dict: Optional[OptimizerStateDict],
+    state_dict: Optional[bytes],
) -> None:
    if state_dict is None:
        return
-    with io.BytesIO() as fobj:
-        torch.save(state_dict, fobj)
-        hf.create_dataset(OPTIMIZER_STATE_DICT_DATASET,
-                          data=np.frombuffer(fobj.getbuffer(), dtype=NP_VOID_DTYPE))
+    hf.create_dataset(OPTIMIZER_STATE_DICT_DATASET,
+                      data=np.frombuffer(state_dict, dtype=NP_VOID_DTYPE))


-def load_optimizer_state_dict(hf: h5py.File) -> Optional[OptimizerStateDict]:
+def load_optimizer_state_dict(hf: h5py.File) -> Optional[bytes]:
    if OPTIMIZER_STATE_DICT_DATASET not in hf:
        return None
-    with io.BufferedReader(DatasetIO(hf[OPTIMIZER_STATE_DICT_DATASET])) as fobj:
-        return torch.load(fobj)
+    return hf[OPTIMIZER_STATE_DICT_DATASET][...].tobytes()


class OneWayMapping:
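On the storage side, the pickled blob now lands in HDF5 directly as a one-dimensional dataset of single-byte opaque elements ("V1") and is read back with a plain [...] slice, replacing the DatasetIO shim. A minimal sketch of that save/load round trip (the file name and dataset path are illustrative, not PBG's actual constants):

```python
import h5py
import numpy as np

NP_VOID_DTYPE = np.dtype("V1")

blob = b"pretend this is torch.save output"

with h5py.File("example_checkpoint.h5", "w") as hf:
    # One "V1" (opaque single-byte) element per byte of the blob.
    hf.create_dataset("optimizer/state_dict",
                      data=np.frombuffer(blob, dtype=NP_VOID_DTYPE))

with h5py.File("example_checkpoint.h5", "r") as hf:
    # [...] materializes the whole dataset; tobytes() flattens it back.
    assert hf["optimizer/state_dict"][...].tobytes() == blob
```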
7 changes: 4 additions & 3 deletions torchbiggraph/types.py
@@ -19,9 +19,10 @@
# To preserve and expose that information, at least to humans, we use more
# informative aliases for torch.Tensor. (PS: FloatTensor and LongTensor are in
# fact instances of the torch.tensortype metaclass).
-CharTensorType = torch.Tensor
-FloatTensorType = torch.Tensor
-LongTensorType = torch.Tensor
+ByteTensorType = torch.Tensor  # uint8
+CharTensorType = torch.Tensor  # int8
+FloatTensorType = torch.Tensor  # float32
+LongTensorType = torch.Tensor  # int64


T = TypeVar("T")
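These aliases are all plain torch.Tensor at runtime; the new trailing comments only document the intended dtype of each. A tiny illustration of the correspondence (assumed from the comments above):

```python
import torch

# Illustrative values matching each alias's documented dtype:
byte_tensor = torch.empty(4, dtype=torch.uint8)     # ByteTensorType
char_tensor = torch.empty(4, dtype=torch.int8)      # CharTensorType
float_tensor = torch.empty(4, dtype=torch.float32)  # FloatTensorType
long_tensor = torch.empty(4, dtype=torch.int64)     # LongTensorType
```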
