Skip to content

Commit

Permalink
enable sparse on windows and mac (dmlc#5277)
Browse files Browse the repository at this point in the history
* enable sparse on windows and mac

* that was stupid

* let's see what's going on..

* [Sparse] Fix the import error on Mac OS.

When DGL template functions that are defined in DGL source files are used,
the macOS loader somehow cannot find their definitions. This fix simply
avoids depending on template functions declared in DGL headers.

With this fix, the sparse tests all pass in the macOS environment.

* ok this is the problem

* make errors clearer

* uh

* test

* Update __init__.py

* disabling ddp on windows

---------

Co-authored-by: czkkkkkk <[email protected]>
  • Loading branch information
BarclayII and czkkkkkk authored Feb 13, 2023
1 parent 465828c commit f62669b
Show file tree
Hide file tree
Showing 19 changed files with 64 additions and 64 deletions.
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

dgl_linux_libs = 'build/libdgl.so, build/runUnitTests, python/dgl/_ffi/_cy3/core.cpython-*-x86_64-linux-gnu.so, build/tensoradapter/pytorch/*.so, build/dgl_sparse/*.so'
// Currently DGL on Windows is not working with Cython yet
dgl_win64_libs = "build\\dgl.dll, build\\runUnitTests.exe, build\\tensoradapter\\pytorch\\*.dll"
dgl_win64_libs = "build\\dgl.dll, build\\runUnitTests.exe, build\\tensoradapter\\pytorch\\*.dll, build\\dgl_sparse\\*.dll"

def init_git() {
sh 'rm -rf *'
Expand Down
11 changes: 1 addition & 10 deletions dgl_sparse/src/spspmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,7 @@ torch::Tensor _CSRMask(
auto val = TorchTensorToDGLArray(value);
auto row = TorchTensorToDGLArray(sub_mat->COOPtr()->row);
auto col = TorchTensorToDGLArray(sub_mat->COOPtr()->col);
runtime::NDArray ret;
if (val->dtype.bits == 32) {
ret = aten::CSRGetData<float>(csr, row, col, val, 0.);
} else if (val->dtype.bits == 64) {
ret = aten::CSRGetData<double>(csr, row, col, val, 0.);
} else {
TORCH_CHECK(
false, "Dtype of value for SpSpMM should be 32 or 64 bits but got: " +
std::to_string(val->dtype.bits));
}
runtime::NDArray ret = aten::CSRGetFloatingData(csr, row, col, val, 0.);
return DGLArrayToTorchTensor(ret);
}

Expand Down
30 changes: 30 additions & 0 deletions include/dgl/aten/csr.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,36 @@ runtime::NDArray CSRGetData(
CSRMatrix, runtime::NDArray rows, runtime::NDArray cols,
runtime::NDArray weights, DType filler);

/**
* @brief Get the data for each (row, col) pair, then index into the weights
* array.
*
* The operator supports matrices with duplicate entries, but only one matched
* entry will be returned for each (row, col) pair. Duplicate input
* (row, col) pairs are supported.
*
* If a (row, col) pair does not correspond to a valid non-zero element to index
* into the weights array, DGL returns the value \a filler for that pair
* instead.
*
* @note This operator allows broadcasting (i.e., either row or col can be of
* length 1).
* @note This is the floating point number version of `CSRGetData`, which
* removes the dtype template.
* @note The weights array must have a 32-bit or 64-bit dtype; any other bit
* width fails a CHECK.
*
* @param mat Sparse matrix.
* @param rows Row index.
* @param cols Column index.
* @param weights The weights array.
* @param filler The value to return for row-column pairs not existent in the
* matrix.
* @return Data array. The i^th element is the data of (rows[i], cols[i]).
*/
runtime::NDArray CSRGetFloatingData(
CSRMatrix, runtime::NDArray rows, runtime::NDArray cols,
runtime::NDArray weights, double filler);

/** @brief Return a transposed CSR matrix */
CSRMatrix CSRTranspose(CSRMatrix csr);

Expand Down
17 changes: 13 additions & 4 deletions python/dgl/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,25 @@
def load_dgl_sparse():
    """Load the DGL C++ sparse library built for the installed PyTorch.

    The library file name is derived from the PyTorch version (with any
    local ``+suffix`` stripped) and from ``sys.platform``. The file is looked
    up in the ``dgl_sparse`` directory next to the first path returned by
    ``libinfo.find_lib_path()``.

    Raises
    ------
    NotImplementedError
        If ``sys.platform`` does not start with ``linux``, ``darwin`` or
        ``win``.
    FileNotFoundError
        If the expected library file does not exist.
    ImportError
        If PyTorch fails to load the library; the original error is chained
        as the cause.
    """
    version = torch.__version__.split("+", maxsplit=1)[0]

    # The shared-library prefix and extension differ per platform.
    if sys.platform.startswith("linux"):
        basename = f"libdgl_sparse_pytorch_{version}.so"
    elif sys.platform.startswith("darwin"):
        basename = f"libdgl_sparse_pytorch_{version}.dylib"
    elif sys.platform.startswith("win"):
        basename = f"dgl_sparse_pytorch_{version}.dll"
    else:
        raise NotImplementedError("Unsupported system: %s" % sys.platform)

    dirname = os.path.dirname(libinfo.find_lib_path()[0])
    path = os.path.join(dirname, "dgl_sparse", basename)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Cannot find DGL C++ sparse library at {path}")

    try:
        torch.classes.load_library(path)
    except Exception as err:  # pylint: disable=W0703
        # Chain the underlying failure so the real loader error stays visible.
        raise ImportError("Cannot load DGL C++ sparse library") from err


# TODO(zhenkun): support other platforms
if sys.platform.startswith("linux"):
load_dgl_sparse()
load_dgl_sparse()
12 changes: 12 additions & 0 deletions src/array/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,18 @@ NDArray CSRGetData(
return ret;
}

/**
 * @brief Non-template wrapper around CSRGetData for floating-point weights.
 *
 * Selects CSRGetData<double> or CSRGetData<float> from the bit width of the
 * weights dtype, so callers do not have to name the template instantiation.
 *
 * @param csr The CSR matrix, forwarded to CSRGetData.
 * @param rows Row indices, forwarded to CSRGetData.
 * @param cols Column indices, forwarded to CSRGetData.
 * @param weights The weights array; its dtype must be 32 or 64 bits wide,
 *        otherwise the CHECK below fails.
 * @param filler Filler value forwarded to CSRGetData; narrowed to float when
 *        the weights are 32 bits wide.
 * @return The array produced by the selected CSRGetData instantiation.
 */
runtime::NDArray CSRGetFloatingData(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
    runtime::NDArray weights, double filler) {
  if (weights->dtype.bits == 64) {
    return CSRGetData<double>(csr, rows, cols, weights, filler);
  }
  // Anything that is not 64 bits wide must be 32 bits wide.
  CHECK(weights->dtype.bits == 32)
      << "CSRGetFloatingData only supports 32 or 64 bits floating number";
  // Make the double -> float narrowing of the filler explicit.
  return CSRGetData<float>(
      csr, rows, cols, weights, static_cast<float>(filler));
}

template NDArray CSRGetData<float>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler);
template NDArray CSRGetData<double>(
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@

from dgl.sparse import diag, DiagMatrix, identity

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_elementwise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
import pytest
import torch

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_elementwise_op_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@

from dgl.sparse import from_coo, power

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


def all_close_sparse(A, row, col, val, shape):
rowA, colA = A.coo()
Expand Down
6 changes: 0 additions & 6 deletions tests/pytorch/sparse/test_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,6 @@
import subprocess
import sys

import pytest

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)

EXAMPLE_ROOT = os.path.join(
os.path.dirname(os.path.relpath(__file__)),
"..",
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@
sparse_matrix_to_torch_sparse,
)

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(2, 7), (5, 2)])
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@
import pytest
import torch

# TODO(#5013): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)

dgl_op_map = {
"sum": "sum",
"amin": "smin",
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_sddmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@

from .utils import clone_detach_and_grad, rand_coo, rand_csc, rand_csr

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(5, 5), (5, 4)])
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@

from dgl.sparse import from_coo, softmax

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("val_D", [None, 2])
@pytest.mark.parametrize("csr", [True, False])
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@

from dgl.sparse import from_coo, from_csc, from_csr, val_like

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("dense_dim", [None, 4])
@pytest.mark.parametrize("row", [(0, 0, 1, 2), (0, 1, 2, 4)])
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@

from dgl.sparse import diag, from_coo

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
Expand Down
1 change: 0 additions & 1 deletion tests/pytorch/sparse/test_unary_op_diag.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import sys

import backend as F
import pytest
import torch

from dgl.sparse import diag
Expand Down
5 changes: 0 additions & 5 deletions tests/pytorch/sparse/test_unary_op_sp.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import sys

import backend as F
import pytest
import torch

from dgl.sparse import from_coo

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


def test_neg():
ctx = F.ctx()
Expand Down
6 changes: 6 additions & 0 deletions tests/pytorch/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def test_neighbor_nonuniform(idtype, mode, use_ddp, use_mask):
if mode != 'cpu' and use_mask:
pytest.skip('Masked sampling only works on CPU.')
if use_ddp:
if os.name == 'nt':
pytest.skip('PyTorch 1.13.0+ has problems in Windows DDP...')
dist.init_process_group('gloo' if F.ctx() == F.cpu() else 'nccl',
'tcp://127.0.0.1:12347', world_size=1, rank=0)
g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1])).astype(idtype)
Expand Down Expand Up @@ -181,6 +183,8 @@ def test_node_dataloader(idtype, sampler_name, mode, use_ddp):
if mode != 'cpu' and F.ctx() == F.cpu():
pytest.skip('UVA and GPU sampling require a GPU.')
if use_ddp:
if os.name == 'nt':
pytest.skip('PyTorch 1.13.0+ has problems in Windows DDP...')
dist.init_process_group('gloo' if F.ctx() == F.cpu() else 'nccl',
'tcp://127.0.0.1:12347', world_size=1, rank=0)
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
Expand Down Expand Up @@ -267,6 +271,8 @@ def test_edge_dataloader(idtype, sampler_name, neg_sampler, mode, use_ddp):
if mode == 'uva' and isinstance(neg_sampler, dgl.dataloading.negative_sampler.GlobalUniform):
pytest.skip("GlobalUniform don't support UVA yet.")
if use_ddp:
if os.name == 'nt':
pytest.skip('PyTorch 1.13.0+ has problems in Windows DDP...')
dist.init_process_group('gloo' if F.ctx() == F.cpu() else 'nccl',
'tcp://127.0.0.1:12347', world_size=1, rank=0)
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/build_dgl.bat
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ SET TEMP=%WORKSPACE%\tmp
SET TMPDIR=%WORKSPACE%\tmp

PUSHD build
cmake -DCMAKE_CXX_FLAGS="/DDGL_EXPORTS" -DUSE_AVX=ON -DUSE_OPENMP=ON -DBUILD_TORCH=ON -Dgtest_force_shared_crt=ON -DDMLC_FORCE_SHARED_CRT=ON -DBUILD_CPP_TEST=1 -DCMAKE_CONFIGURATION_TYPES="Release" .. -G "Visual Studio 16 2019" || EXIT /B 1
cmake -DCMAKE_CXX_FLAGS="/DDGL_EXPORTS" -DUSE_AVX=ON -DUSE_OPENMP=ON -DBUILD_TORCH=ON -Dgtest_force_shared_crt=ON -DDMLC_FORCE_SHARED_CRT=ON -DBUILD_CPP_TEST=1 -DCMAKE_CONFIGURATION_TYPES="Release" -DTORCH_PYTHON_INTERPS=python -DBUILD_SPARSE=ON .. -G "Visual Studio 16 2019" || EXIT /B 1
msbuild dgl.sln /m /nr:false || EXIT /B 1
COPY /Y Release\runUnitTests.exe .
POPD
Expand Down

0 comments on commit f62669b

Please sign in to comment.