forked from facebookresearch/PyTorch-BigGraph
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fileio_tests.py
130 lines (109 loc) · 5.24 KB
/
fileio_tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import io
import json
import tempfile
from typing import Any
from unittest import TestCase, main
import h5py
import numpy as np
import torch
from torchbiggraph.config import EntitySchema, RelationSchema, ConfigSchema
from torchbiggraph.fileio import DatasetIO, Mapping, ConfigMetadataProvider
class TestDatasetIO(TestCase):
    """Tests for DatasetIO used as a file-like source for ``torch.load``."""

    # DatasetIO is only used wrapped in a BufferedReader as a source for
    # torch.load, hence we test it only in this setting.

    @staticmethod
    def save_to(hf: h5py.File, name: str, data: Any) -> None:
        """Serialize ``data`` with ``torch.save`` into dataset ``name`` of ``hf``.

        The serialized bytes are stored as a one-dimensional array of
        single-byte void ("V1") elements.
        """
        with io.BytesIO() as bf:
            torch.save(data, bf)
            hf.create_dataset(
                name, data=np.frombuffer(bf.getbuffer(), dtype=np.dtype("V1")))

    @staticmethod
    def load_from(hf: h5py.File, name: str) -> Any:
        """Read dataset ``name`` of ``hf`` back through DatasetIO + torch.load."""
        with io.BufferedReader(DatasetIO(hf[name])) as bf:
            return torch.load(bf)

    def test_scalars(self):
        # A nested structure of plain Python values must survive a round trip.
        data = (["a", b"b"], {1: True, 0.2: {None, 4j}})
        # FIXME h5py-2.9 accepts just File(bf), allowing an un-Named TemporaryFile.
        with tempfile.NamedTemporaryFile() as bf:
            # Write and close the HDF5 file before reopening it for reading.
            with h5py.File(bf.name, "w") as hf:
                self.save_to(hf, "foo", data)
            with h5py.File(bf.name, "r") as hf:
                self.assertEqual(self.load_from(hf, "foo"), data)

    def test_tensors(self):
        # Two tensors of different dtype and shape stored side by side in the
        # same file must each be read back unchanged.
        data_foo = torch.zeros((100,), dtype=torch.int8)
        data_bar = torch.ones((10, 10))
        # FIXME h5py-2.9 accepts just File(bf), allowing an un-Named TemporaryFile.
        with tempfile.NamedTemporaryFile() as bf:
            with h5py.File(bf.name, "w") as hf:
                self.save_to(hf, "foo", data_foo)
                self.save_to(hf, "bar", data_bar)
            with h5py.File(bf.name, "r") as hf:
                self.assertTrue(data_foo.equal(self.load_from(hf, "foo")))
                self.assertTrue(data_bar.equal(self.load_from(hf, "bar")))

    def test_bad_args(self):
        # DatasetIO must reject datasets that are not a one-dimensional array
        # of "V1" elements by raising TypeError at construction time.
        # FIXME h5py-2.9 accepts just File(bf), allowing an un-Named TemporaryFile.
        with tempfile.NamedTemporaryFile() as bf:
            with h5py.File(bf.name, "w") as hf:
                # Scalar array of "V<length>" type as suggested in the h5py doc.
                data = np.void(b"data")
                with self.assertRaises(TypeError):
                    DatasetIO(hf.create_dataset("foo", data=data))
                # One-dimensional array of uint8 type.
                data = np.frombuffer(b"data", dtype=np.uint8)
                with self.assertRaises(TypeError):
                    DatasetIO(hf.create_dataset("bar", data=data))
                # Two-dimensional array of bytes.
                data = np.frombuffer(b"data", dtype=np.dtype("V1")).reshape(2, 2)
                with self.assertRaises(TypeError):
                    DatasetIO(hf.create_dataset("baz", data=data))
class TestMapping(TestCase):
    """Tests for Mapping's private<->public name translation."""

    def test_one_field(self):
        # A single placeholder is carried over between the two patterns in
        # both directions.
        m = Mapping("foo.bar.{field}", "{field}/ham/eggs", fields=["field"])
        self.assertEqual(m.private_to_public.map("foo.bar.baz"), "baz/ham/eggs")
        self.assertEqual(m.public_to_private.map("spam/ham/eggs"), "foo.bar.spam")

        # Private names that do not fit "foo.bar.{field}" are rejected:
        # altered literals, missing or empty field, extra trailing or
        # leading components.
        with self.assertRaises(ValueError):
            m.private_to_public.map("f00.b4r.b4z")
        with self.assertRaises(ValueError):
            m.private_to_public.map("foo.bar")
        with self.assertRaises(ValueError):
            m.private_to_public.map("foo.bar.")
        with self.assertRaises(ValueError):
            m.private_to_public.map("foo.bar.baz.2")
        with self.assertRaises(ValueError):
            m.private_to_public.map("2.foo.bar.baz")

        # Same checks for public names against "{field}/ham/eggs".
        with self.assertRaises(ValueError):
            m.public_to_private.map("sp4m/h4m/3gg5")
        with self.assertRaises(ValueError):
            m.public_to_private.map("ham/eggs")
        with self.assertRaises(ValueError):
            m.public_to_private.map("/ham/eggs")
        with self.assertRaises(ValueError):
            m.public_to_private.map("2/spam/ham/eggs")
        with self.assertRaises(ValueError):
            m.public_to_private.map("spam/ham/eggs/2")

    def test_many_field(self):
        # Several placeholders embedded inside path components, appearing in
        # a different order in the private and public patterns.
        m = Mapping("fo{field1}.{field2}ar.b{field3}z",
                    "sp{field3}m/{field2}am/egg{field1}",
                    fields=["field1", "field2", "field3"])
        self.assertEqual(m.private_to_public.map("foo.bar.baz"), "spam/bam/eggo")
        self.assertEqual(m.public_to_private.map("spam/ham/eggs"), "fos.har.baz")
class TestConfigMetadataProvider(TestCase):
    """Tests for the checkpoint metadata emitted by ConfigMetadataProvider."""

    def test_basic(self):
        # Build a minimal config: one entity type, one relation, dimension 1.
        config = ConfigSchema(
            entities={"e": EntitySchema(num_partitions=1)},
            relations=[RelationSchema(name="r", lhs="e", rhs="e")],
            dimension=1,
            entity_path="foo", edge_paths=["bar"], checkpoint_path="baz")
        metadata = ConfigMetadataProvider(config).get_checkpoint_metadata()
        # The metadata is a dict with exactly one key, "config/json" ...
        self.assertIsInstance(metadata, dict)
        self.assertCountEqual(metadata.keys(), ["config/json"])
        # ... whose value is JSON that deserializes back to an equal config.
        self.assertEqual(
            config, ConfigSchema.from_dict(json.loads(metadata["config/json"])))
# Run the unittest runner when this file is executed as a script.
if __name__ == '__main__':
    main()