Skip to content

Commit

Permalink
Have edge storage raise proper exception
Browse files Browse the repository at this point in the history
Summary: The interface prescribes raising CouldNotLoadData when a file is missing, but we were raising RuntimeError.

Reviewed By: adamlerer

Differential Revision: D17571776

fbshipit-source-id: 04af21aa26a38fed235a8084c863fc6b5a4c0cee
  • Loading branch information
lw authored and facebook-github-bot committed Sep 26, 2019
1 parent 9c9e809 commit 53c9ce2
Showing 1 changed file with 47 additions and 36 deletions.
83 changes: 47 additions & 36 deletions torchbiggraph/graph_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE.txt file in the root directory of this source tree.

import errno
import json
import logging
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -394,12 +395,17 @@ def has_edges(

def get_number_of_edges(self, lhs_p: int, rhs_p: int) -> int:
file_path = self.get_edges_file(lhs_p, rhs_p)
if not file_path.is_file():
raise RuntimeError(f"{file_path} does not exist")
with h5py.File(file_path, "r") as hf:
if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
raise RuntimeError(f"Version mismatch in edge file {file_path}")
return hf["rel"].len()
try:
with h5py.File(file_path, "r") as hf:
if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
raise RuntimeError(f"Version mismatch in edge file {file_path}")
return hf["rel"].len()
except OSError as err:
# h5py refuses to make it easy to figure out what went wrong. The errno
# attribute is set to None. See https://github.com/h5py/h5py/issues/493.
if f"errno = {errno.ENOENT}" in str(err):
raise CouldNotLoadData() from err
raise err

def load_chunk_of_edges(
self,
Expand All @@ -409,36 +415,41 @@ def load_chunk_of_edges(
num_chunks: int = 1,
) -> EdgeList:
file_path = self.get_edges_file(lhs_p, rhs_p)
if not file_path.is_file():
raise RuntimeError(f"{file_path} does not exist")
with h5py.File(file_path, 'r') as hf:
if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
raise RuntimeError(f"Version mismatch in edge file {file_path}")
lhs_ds = hf['lhs']
rhs_ds = hf['rhs']
rel_ds = hf['rel']

num_edges = rel_ds.len()
begin = int(chunk_idx * num_edges / num_chunks)
end = int((chunk_idx + 1) * num_edges / num_chunks)
chunk_size = end - begin

lhs = torch.empty((chunk_size,), dtype=torch.long)
rhs = torch.empty((chunk_size,), dtype=torch.long)
rel = torch.empty((chunk_size,), dtype=torch.long)

# Needed because https://github.com/h5py/h5py/issues/870.
if chunk_size > 0:
lhs_ds.read_direct(lhs.numpy(), source_sel=np.s_[begin:end])
rhs_ds.read_direct(rhs.numpy(), source_sel=np.s_[begin:end])
rel_ds.read_direct(rel.numpy(), source_sel=np.s_[begin:end])

lhsd = self.read_dynamic(hf, 'lhsd', begin, end)
rhsd = self.read_dynamic(hf, 'rhsd', begin, end)

return EdgeList(EntityList(lhs, lhsd),
EntityList(rhs, rhsd),
rel)
try:
with h5py.File(file_path, "r") as hf:
if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
raise RuntimeError(f"Version mismatch in edge file {file_path}")
lhs_ds = hf["lhs"]
rhs_ds = hf["rhs"]
rel_ds = hf["rel"]

num_edges = rel_ds.len()
begin = int(chunk_idx * num_edges / num_chunks)
end = int((chunk_idx + 1) * num_edges / num_chunks)
chunk_size = end - begin

lhs = torch.empty((chunk_size,), dtype=torch.long)
rhs = torch.empty((chunk_size,), dtype=torch.long)
rel = torch.empty((chunk_size,), dtype=torch.long)

# Needed because https://github.com/h5py/h5py/issues/870.
if chunk_size > 0:
lhs_ds.read_direct(lhs.numpy(), source_sel=np.s_[begin:end])
rhs_ds.read_direct(rhs.numpy(), source_sel=np.s_[begin:end])
rel_ds.read_direct(rel.numpy(), source_sel=np.s_[begin:end])

lhsd = self.read_dynamic(hf, "lhsd", begin, end)
rhsd = self.read_dynamic(hf, "rhsd", begin, end)

return EdgeList(EntityList(lhs, lhsd),
EntityList(rhs, rhsd),
rel)
except OSError as err:
# h5py refuses to make it easy to figure out what went wrong. The errno
# attribute is set to None. See https://github.com/h5py/h5py/issues/493.
if f"errno = {errno.ENOENT}" in str(err):
raise CouldNotLoadData() from err
raise err

@staticmethod
def read_dynamic(
Expand Down

0 comments on commit 53c9ce2

Please sign in to comment.