forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_nccl.py
191 lines (147 loc) · 7.23 KB
/
test_nccl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import unittest
import sys
import torch
import torch.cuda.nccl as nccl
import torch.cuda
from torch.testing._internal.common_utils import (TestCase, run_tests,
IS_WINDOWS, load_tests,
TEST_WITH_ROCM)
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
import re
# HIP (ROCm) version as a float built from the leading "major.minor" of
# torch.version.hip; 0.0 when this is not a ROCm build.
# NOTE(review): float parsing misorders two-digit minor versions
# (e.g. "3.10" -> 3.1, which compares < 3.5). Consider comparing
# (major, minor) tuples instead -- callers currently compare against floats.
if torch.version.hip is None:
    HIP_VERSION = 0.0
else:
    HIP_VERSION = float(re.search(r"^\d+\.\d+", torch.version.hip)[0])

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

nGPUs = torch.cuda.device_count()

if not TEST_CUDA:
    print('CUDA not available, skipping tests', file=sys.stderr)
    # Drop the unittest base class; presumably so the tests below are not
    # collected/run without CUDA -- confirm against the test harness.
    TestCase = object  # noqa: F811

# Dtypes each collective test is parametrized over; bfloat16 only on ROCm.
datatypes = [torch.float]
if TEST_WITH_ROCM:
    datatypes.append(torch.bfloat16)
class TestNCCL(TestCase):
    """Tests for the ``torch.cuda.nccl`` collective bindings.

    Each collective is run on one tensor per GPU, with the inputs passed as a
    list and again as a tuple (``all_reduce`` also as a set), and the result is
    compared against a reference computed on the CPU.
    """

    @unittest.skipIf(IS_WINDOWS, "NCCL doesn't support Windows")
    def test_unique_id(self, device):
        """``nccl.unique_id()`` returns a ``bytes`` object longer than one byte."""
        uid = nccl.unique_id()
        self.assertIsInstance(uid, bytes)
        self.assertGreater(len(uid), 1)

    @unittest.skipIf(TEST_WITH_ROCM and HIP_VERSION < 3.5, 'Skip NCCL tests for ROCm')
    @unittest.skipIf(IS_WINDOWS, "NCCL doesn't support Windows")
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_broadcast(self, device, dtype):
        """After ``broadcast``, every per-GPU tensor equals the source tensor."""
        expected = torch.zeros(128).uniform_().to(dtype=dtype)
        for container in (list, tuple):
            # Source copy on the current CUDA device (presumably GPU 0) and
            # zeros on every other GPU. The loop index is named ``dev_idx`` so
            # it does not shadow the test's ``device`` argument.
            tensors = [expected.cuda()]
            tensors.extend(torch.zeros(128, dtype=dtype, device=dev_idx)
                           for dev_idx in range(1, nGPUs))
            nccl.broadcast(container(tensors))
            for tensor in tensors:
                self.assertEqual(tensor, expected)

    @unittest.skipIf(TEST_WITH_ROCM and HIP_VERSION < 3.5, 'Skip NCCL tests for ROCm')
    @unittest.skipIf(IS_WINDOWS, "NCCL doesn't support Windows")
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_reduce(self, device, dtype):
        """``reduce`` leaves the sum of all per-GPU inputs in the tensor on GPU 0."""
        cpu_tensors = [torch.zeros(128).uniform_().to(dtype=dtype) for _ in range(nGPUs)]
        # Accumulate in ``dtype`` so the reference rounds like the GPU result.
        expected = torch.zeros(128, dtype=dtype)
        for t in cpu_tensors:
            expected.add_(t)
        for container in (list, tuple):
            tensors = [t.cuda(i) for i, t in enumerate(cpu_tensors)]
            nccl.reduce(container(tensors))
            self.assertEqual(tensors[0], expected)

    @unittest.skipIf(IS_WINDOWS, "NCCL doesn't support Windows")
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_all_reduce(self, device, dtype):
        """After ``all_reduce``, every per-GPU tensor holds the sum of all inputs."""
        # Unlike the other multi-GPU tests, only bfloat16 is skipped on ROCm < 3.5.
        if TEST_WITH_ROCM and HIP_VERSION < 3.5 and dtype == torch.bfloat16:
            raise unittest.SkipTest("Skip bfloat16 test for ROCm < 3.5")
        cpu_tensors = [torch.zeros(128).uniform_().to(dtype=dtype) for _ in range(nGPUs)]
        expected = torch.zeros(128, dtype=dtype)
        for t in cpu_tensors:
            expected.add_(t)
        # A set has arbitrary iteration order; every tensor is checked against
        # the same sum, so the order does not affect the assertions.
        for container in (list, tuple, set):
            tensors = [t.cuda(i) for i, t in enumerate(cpu_tensors)]
            nccl.all_reduce(container(tensors))
            for tensor in tensors:
                self.assertEqual(tensor, expected)

    @unittest.skipIf(TEST_WITH_ROCM and HIP_VERSION < 3.5, 'Skip NCCL tests for ROCm')
    @unittest.skipIf(IS_WINDOWS, "NCCL doesn't support Windows")
    def test_collective_errors(self, device):
        """Passing a bare tensor instead of a collection raises ``TypeError``."""
        t = torch.rand(10).cuda(0)
        msg = "Inputs should be a collection of tensors"
        with self.assertRaisesRegex(TypeError, msg):
            nccl.all_reduce(t)
        with self.assertRaisesRegex(TypeError, msg):
            nccl.reduce(t)
        with self.assertRaisesRegex(TypeError, msg):
            nccl.broadcast(t)
        with self.assertRaisesRegex(TypeError, msg):
            nccl.all_gather(t, t)
        with self.assertRaisesRegex(TypeError, msg):
            nccl.reduce_scatter(t, t)

    @unittest.skipIf(TEST_WITH_ROCM and HIP_VERSION < 3.5, 'Skip NCCL tests for ROCm')
    @unittest.skipIf(IS_WINDOWS, "NCCL doesn't support Windows")
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_all_gather(self, device, dtype):
        """After ``all_gather``, each GPU's output is the concatenation of all inputs."""
        cpu_inputs = [torch.zeros(128).uniform_().to(dtype=dtype) for _ in range(nGPUs)]
        expected = torch.cat(cpu_inputs, 0)
        for container in (list, tuple):
            inputs = [t.cuda(i) for i, t in enumerate(cpu_inputs)]
            outputs = [torch.zeros(128 * nGPUs, device=i, dtype=dtype)
                       for i in range(nGPUs)]
            nccl.all_gather(container(inputs), container(outputs))
            for tensor in outputs:
                self.assertEqual(tensor, expected)

    @unittest.skipIf(TEST_WITH_ROCM and HIP_VERSION < 3.5, 'Skip NCCL tests for ROCm')
    @unittest.skipIf(IS_WINDOWS, "NCCL doesn't support Windows")
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_reduce_scatter(self, device, dtype):
        """After ``reduce_scatter``, GPU i holds chunk i of the element-wise sum."""
        out_size = 32
        in_size = out_size * nGPUs
        cpu_inputs = [torch.zeros(in_size).uniform_().to(dtype=dtype) for _ in range(nGPUs)]
        expected = torch.zeros(in_size, dtype=dtype)
        for t in cpu_inputs:
            expected.add_(t)
        # Row i of the (nGPUs, out_size) view is the chunk expected on GPU i.
        # Derived from out_size so it cannot drift from the output buffers.
        expected = expected.view(nGPUs, out_size)
        for container in (list, tuple):
            inputs = [t.cuda(i) for i, t in enumerate(cpu_inputs)]
            outputs = [torch.zeros(out_size, device=i, dtype=dtype)
                       for i in range(nGPUs)]
            nccl.reduce_scatter(container(inputs), container(outputs))
            for i, output in enumerate(outputs):
                self.assertEqual(output, expected[i])
# Instantiate the device-type variants of TestNCCL into this module's
# globals, restricted to CUDA via only_for='cuda'.
instantiate_device_type_tests(
    TestNCCL,
    globals(),
    only_for='cuda',
)

# Script entry point: hand control to the shared PyTorch test runner.
if __name__ == '__main__':
    run_tests()