Skip to content

Commit

Permalink
first io package contribution
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Oct 2, 2019
1 parent d2c8ef6 commit b62e1b0
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 6 deletions.
1 change: 0 additions & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
source=torch_geometric
omit=
torch_geometric/datasets/*
torch_geometric/io/*
torch_geometric/data/extract.py
torch_geometric/nn/data_parallel.py
[report]
Expand Down
8 changes: 8 additions & 0 deletions test/io/example1.off
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
OFF
4 2 0
0.0 0.0 0.0
0.0 1.0 0.0
1.0 0.0 0.0
1.0 1.0 0.0
3 0 1 2
3 1 2 3
7 changes: 7 additions & 0 deletions test/io/example2.off
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
OFF
4 1 0
0.0 0.0 0.0
0.0 1.0 0.0
1.0 0.0 0.0
1.0 1.0 0.0
4 0 1 2 3
31 changes: 31 additions & 0 deletions test/io/test_off.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os
import os.path as osp

import torch
from torch_geometric.data import Data
from torch_geometric.io import read_off, write_off


def test_read_off():
data = read_off(osp.join('test', 'io', 'example1.off'))
assert len(data) == 2
assert data.pos.tolist() == [[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]]
assert data.face.tolist() == [[0, 1], [1, 2], [2, 3]]

data = read_off(osp.join('test', 'io', 'example2.off'))
assert len(data) == 2
assert data.pos.tolist() == [[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]]
assert data.face.tolist() == [[0, 1], [1, 2], [2, 3]]


def test_write_off():
pos = torch.tensor([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]])
face = torch.tensor([[0, 1], [1, 2], [2, 3]])

path = osp.join('test', 'io', 'example.off')
write_off(Data(pos=pos, face=face), path)
data = read_off(path)
os.unlink(path)

assert data.pos.tolist() == pos.tolist()
assert data.face.tolist() == face.tolist()
3 changes: 2 additions & 1 deletion torch_geometric/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .ply import read_ply
from .obj import read_obj
from .sdf import read_sdf, parse_sdf
from .off import read_off, parse_off
from .off import read_off, parse_off, write_off
from .npz import read_npz, parse_npz

__all__ = [
Expand All @@ -17,6 +17,7 @@
'read_sdf',
'parse_sdf',
'read_off',
'write_off',
'parse_off',
'read_npz',
'parse_npz',
Expand Down
44 changes: 40 additions & 4 deletions torch_geometric/io/off.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import re

import torch
from torch._tensor_str import PRINT_OPTS
from torch_geometric.io import parse_txt_array
from torch_geometric.data import Data

REGEX = r'[\[\]\(\),]|tensor|(\s)\1{2,}'


def parse_off(src):
# Some files may contain a bug and do not have a carriage return after OFF.
Expand Down Expand Up @@ -33,14 +38,45 @@ def face_to_tri(face):
rect = rect.to(torch.int64)

if rect.numel() > 0:
first, second = rect[:, [0, 1, 2]], rect[:, [0, 2, 3]]
first, second = rect[:, [0, 1, 2]], rect[:, [1, 2, 3]]
return torch.cat([triangle, first, second], dim=0).t().contiguous()
else:
first, second = rect, rect

return torch.cat([triangle, first, second], dim=0).t().contiguous()
return triangle.t().contiguous()


def read_off(path):
r"""Reads an OFF (Object File Format) file, returning both the position of
nodes and their connectivity in a :class:`torch_geometric.data.Data`
object.
Args:
path (str): The path to the file.
"""
with open(path, 'r') as f:
src = f.read().split('\n')[:-1]
return parse_off(src)


def write_off(data, path):
r"""Writes a :class:`torch_geometric.data.Data` object to an OFF (Object
File Format) file.
Args:
data (:class:`torch_geometric.data.Data`): The data object.
path (str): The path to the file.
"""
num_nodes, num_faces = data.pos.size(0), data.face.size(1)

face = data.face.t()
num_vertices = torch.full((num_faces, 1), face.size(1), dtype=torch.long)
face = torch.cat([num_vertices, face], dim=-1)

threshold = PRINT_OPTS.threshold
torch.set_printoptions(threshold=float('inf'))
with open(path, 'w') as f:
f.write('OFF\n{} {} 0\n'.format(num_nodes, num_faces))
f.write(re.sub(REGEX, r'', data.pos.__repr__()))
f.write('\n')
f.write(re.sub(REGEX, r'', face.__repr__()))
f.write('\n')
torch.set_printoptions(threshold=threshold)

0 comments on commit b62e1b0

Please sign in to comment.