forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path__init__.py
416 lines (293 loc) · 10.8 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
"""
This package adds support for CUDA tensor types, which implement the same
functions as CPU tensors but use GPUs for computation.
It is lazily initialized, so you can always import it, and use
:func:`is_available()` to determine if your system supports CUDA.
:ref:`cuda-semantics` has more details about working with CUDA.
"""
import contextlib
import platform
import ctypes
import os
import torch
from multiprocessing.util import register_after_fork as _register_after_fork
# Module-level CUDA bookkeeping.
# Becomes True once _lazy_init() completes; _after_fork() resets it to False.
_initialized = False
# Set to True by _after_fork() in a process whose PID differs from the one
# that initialized CUDA; _lazy_init() then raises instead of re-initializing.
_in_bad_fork = False # this global is also used in torch.manual_seed
# PID of the process that ran _lazy_init(); stays False until then.
_original_pid = False
# ctypes handle returned by _load_cudart(), filled in by _lazy_init() and
# returned by cudart().
_cudart = None
def is_available():
    """Report whether CUDA can be used in this process.

    Returns True only when torch exposes the CUDA driver check, that check
    reports the driver as sufficient, and at least one device is visible.
    """
    has_driver_check = hasattr(torch._C, '_cuda_isDriverSufficient')
    if has_driver_check and torch._C._cuda_isDriverSufficient():
        return torch._C._cuda_getDeviceCount() > 0
    return False
def _sleep(cycles):
    """Forward ``cycles`` to the native ``_cuda_sleep`` helper.

    NOTE(review): presumably enqueues a GPU spin of that many clock cycles
    (used for testing stream ordering) — confirm against the C binding.
    """
    native_sleep = torch._C._cuda_sleep
    native_sleep(cycles)
def _load_cudart():
    """Return a ctypes handle that exposes the CUDA runtime (libcudart) symbols.

    Only the main program is searched (``LoadLibrary(None)``), so this
    succeeds when the CUDA runtime symbols are already present in the
    running process.

    Raises:
        RuntimeError: if the main program does not export
            ``cudaGetErrorName``. The message names the library search-path
            variable for the current platform.
    """
    # First check the main program for CUDA symbols
    lib = ctypes.cdll.LoadLibrary(None)
    if hasattr(lib, 'cudaGetErrorName'):
        return lib

    path_var = ('DYLD_LIBRARY_PATH' if platform.system() == 'Darwin'
                else 'LD_LIBRARY_PATH')
    # Note the trailing space after "in a": adjacent literals concatenate
    # without a separator.
    raise RuntimeError(
        "couldn't find libcudart. Make sure CUDA libraries are installed in a "
        "default location, or that they're in {}.".format(path_var))
def _check_driver():
    """Ensure torch has CUDA bindings and the NVIDIA driver is usable.

    Returns None when the native driver check passes.

    Raises:
        AssertionError: when torch was built without the CUDA bindings,
            when no NVIDIA driver is found (driver version 0), or when the
            installed driver is reported as insufficient (too old).
    """
    has_bindings = hasattr(torch._C, '_cuda_isDriverSufficient')
    if not has_bindings:
        raise AssertionError("Torch not compiled with CUDA enabled")
    if torch._C._cuda_isDriverSufficient():
        return

    missing_driver = torch._C._cuda_getDriverVersion() == 0
    if missing_driver:
        # found no NVIDIA driver on the system
        raise AssertionError("""
Found no NVIDIA driver on your system. Please check that you
have an NVIDIA GPU and installed a driver from
http://www.nvidia.com/Download/index.aspx""")

    # TODO: directly link to the alternative bin that needs install
    found_version = str(torch._C._cuda_getDriverVersion())
    raise AssertionError("""
The NVIDIA driver on your system is too old (found version {}).
Please update your GPU driver by downloading and installing a new
version from the URL: http://www.nvidia.com/Download/index.aspx
Alternatively, go to: https://pytorch.org/binaries to install
a PyTorch version that has been compiled with your version
of the CUDA driver.""".format(found_version))
def _lazy_init():
    """Initialize CUDA for this process on first use; later calls are no-ops.

    Checks the driver, runs the native CUDA and sparse initializers, loads
    the CUDA runtime through ctypes, and records the PID of the process that
    performed the initialization.

    Raises:
        RuntimeError: in a process flagged by ``_after_fork`` as a forked
            child of an initialized process, or (from ``_load_cudart``) if
            libcudart cannot be found.
        AssertionError: from ``_check_driver`` when torch lacks CUDA support
            or the NVIDIA driver is missing or insufficient.
    """
    global _initialized, _cudart, _original_pid
    if _initialized:
        return
    if _in_bad_fork:
        # CUDA cannot be re-initialized in a forked child; the message points
        # users at the 'spawn' start method (which needs Python 3.4+).
        from sys import version_info
        if version_info < (3, 4):
            msg = ("To use CUDA with multiprocessing, you must use Python "
                   "3.4+ and the 'spawn' start method")
        else:
            msg = ("To use CUDA with multiprocessing, you must use the "
                   "'spawn' start method")
        raise RuntimeError(
            "Cannot re-initialize CUDA in forked subprocess. " + msg)
    _check_driver()
    torch._C._cuda_init()
    torch._C._cuda_sparse_init()
    _cudart = _load_cudart()
    # Declare that these runtime functions return C strings (ctypes would
    # otherwise assume an int return type).
    _cudart.cudaGetErrorName.restype = ctypes.c_char_p
    _cudart.cudaGetErrorString.restype = ctypes.c_char_p
    # Remembered so _after_fork() can tell a different (forked) process apart.
    _original_pid = os.getpid()
    # Set last: if any step above raises, the module stays uninitialized and
    # the next call retries.
    _initialized = True
def _after_fork(arg):
    """Post-fork hook registered with ``multiprocessing.util``.

    Does nothing unless CUDA was initialized and the current PID differs
    from the PID that initialized it. In that case it:

    - marks CUDA as uninitialized;
    - flags the process as a bad fork, so ``_lazy_init()`` raises;
    - re-installs the lazy ``__new__`` hook on ``_CudaBase``.

    ``arg`` is the object supplied at registration and is not used.
    """
    global _initialized, _in_bad_fork
    if not _initialized or _original_pid == os.getpid():
        return
    _initialized = False
    _in_bad_fork = True
    _CudaBase.__new__ = _lazy_new


# Register the hook. The same function is passed as both the registered
# object and the callback, so presumably ``arg`` receives ``_after_fork``
# itself — confirm against multiprocessing.util.
_register_after_fork(_after_fork, _after_fork)
def cudart():
    """Return the ctypes handle to the CUDA runtime, initializing CUDA if needed."""
    _lazy_init()
    runtime = _cudart
    return runtime
class device(object):
    """Context-manager that changes the selected device.

    Arguments:
        idx (int): device index to select. It's a no-op if this argument
            is negative.
    """

    def __init__(self, idx):
        self.idx = idx
        # -1 means "nothing was switched"; __enter__ overwrites it with the
        # device that was current before the switch.
        self.prev_idx = -1

    def __enter__(self):
        # Compare by value, not identity (``is`` on ints is unreliable and
        # warns on Python 3.8+). Any negative index is a no-op, as documented.
        if self.idx < 0:
            return
        _lazy_init()
        self.prev_idx = torch._C._cuda_getDevice()
        if self.prev_idx != self.idx:
            torch._C._cuda_setDevice(self.idx)

    def __exit__(self, *args):
        # Mirror __enter__: nothing was changed for a negative index, so there
        # is nothing to restore.
        if self.idx >= 0 and self.prev_idx != self.idx:
            torch._C._cuda_setDevice(self.prev_idx)
        # Never suppress exceptions raised inside the with-body.
        return False
class device_of(device):
    """Context-manager that switches to the device holding ``obj``.

    Accepts tensors and storages alike. When ``obj`` is not a CUDA object,
    index -1 is passed on to :class:`device`, making the manager a no-op.

    Arguments:
        obj (Tensor or Storage): object allocated on the selected device.
    """

    def __init__(self, obj):
        if obj.is_cuda:
            target = obj.get_device()
        else:
            target = -1
        super(device_of, self).__init__(target)
def set_device(device):
    """Make ``device`` the current CUDA device.

    Prefer the :any:`device` context manager. In most cases, restricting
    devices with the ``CUDA_VISIBLE_DEVICES`` environment variable is better
    still.

    Arguments:
        device (int): index of the device to select. Negative values make
            this call a no-op.
    """
    if device < 0:
        return
    torch._C._cuda_setDevice(device)
@contextlib.contextmanager
def stream(stream):
    """Context-manager that makes ``stream`` the current stream.

    CUDA kernels queued inside the ``with`` body are enqueued on ``stream``.
    The previously current stream is restored on exit, even if the body
    raises.

    Arguments:
        stream (Stream): stream to select. Passing ``None`` makes this
            manager a no-op.
    """
    if stream is not None:
        previous = current_stream()
        torch._C._cuda_setStream(stream._cdata)
        try:
            yield
        finally:
            torch._C._cuda_setStream(previous._cdata)
    else:
        yield
def device_count():
    """Return how many CUDA devices are available (0 when CUDA is unusable)."""
    if not is_available():
        return 0
    _lazy_init()
    return torch._C._cuda_getDeviceCount()
def current_device():
    """Return the index of the selected CUDA device, initializing CUDA if needed."""
    _lazy_init()
    selected = torch._C._cuda_getDevice()
    return selected
def synchronize():
    """Block until all kernels in all streams on the current device complete."""
    _lazy_init()
    outcome = torch._C._cuda_synchronize()
    return outcome
def current_stream():
    """Return the currently selected :class:`Stream`, initializing CUDA if needed."""
    _lazy_init()
    stream_cls = torch.cuda.Stream
    return stream_cls(_cdata=torch._C._cuda_getCurrentStream())
def current_blas_handle():
    """Return a ``cublasHandle_t`` pointer to the current cuBLAS handle.

    Runs ``_lazy_init()`` first, like the other accessors in this module
    (``current_device``, ``synchronize``, ``current_stream``). This ensures
    the native handle is not queried before lazy initialization has run.
    """
    _lazy_init()
    return torch._C._cuda_getCurrentBlasHandle()
def _host_allocator():
    """Return the native CUDA host allocator object, initializing CUDA first.

    NOTE(review): presumably the pinned (page-locked) host allocator —
    confirm against the C binding.
    """
    _lazy_init()
    allocator = torch._C._cuda_cudaHostAllocator()
    return allocator
@contextlib.contextmanager
def _free_mutex():
    """Hold the native CUDA mutex (``_cuda_lock_mutex``) for the ``with`` body.

    NOTE(review): presumably guards the caching allocator's free path —
    confirm against the C binding.
    """
    torch._C._cuda_lock_mutex()
    try:
        yield None
    finally:
        # Always release, even when the body raises.
        torch._C._cuda_unlock_mutex()
from .random import *
################################################################################
# Define Storage and Tensor classes
################################################################################
from ..tensor import _TensorBase
from ..storage import _StorageBase
def _dummy_type(name):
    """Create a placeholder class named ``name`` that cannot be instantiated.

    Used when ``torch._C`` lacks the CUDA base classes, so that CUDA class
    definitions still have base classes to inherit from.

    Arguments:
        name (str): ``__name__`` given to the generated class.

    Returns:
        type: a new class deriving from ``object`` whose ``__init__``
        always raises ``RuntimeError``.
    """
    def init_err(self, *args, **kwargs):
        # Accept any arguments so that every instantiation attempt reports
        # this error rather than a TypeError about the signature.
        class_name = self.__class__.__name__
        raise RuntimeError(
            "Tried to instantiate dummy base class {}".format(class_name))
    # Name the class after the ``name`` argument (not a module-level
    # variable), so each dummy carries its own name.
    return type(name, (object,), {"__init__": init_err})
if not hasattr(torch._C, 'CudaDoubleStorageBase'):
    # Define dummy base classes
    # torch._C does not provide the CUDA storage/tensor bases (presumably a
    # build without CUDA), so placeholder classes are installed under the
    # same names. This lets the class statements that reference
    # torch._C.Cuda*Base still execute; instantiating a placeholder fails.
    for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half']:
        storage_name = 'Cuda{0}StorageBase'.format(t)
        tensor_name = 'Cuda{0}TensorBase'.format(t)
        torch._C.__dict__[storage_name] = _dummy_type(storage_name)
        torch._C.__dict__[tensor_name] = _dummy_type(tensor_name)
    # NOTE: ``t``, ``storage_name`` and ``tensor_name`` remain bound as module
    # globals after this loop; keep this statement order unchanged.
    torch._C.__dict__['_CudaStreamBase'] = _dummy_type('CudaStreamBase')
# Wrapped in staticmethod explicitly because it is assigned to
# ``_CudaBase.__new__`` as a plain class attribute; Python applies that
# wrapping implicitly only for ``__new__`` defined inside a class body.
@staticmethod
def _lazy_new(cls, *args, **kwargs):
    """``__new__`` hook on ``_CudaBase`` that defers CUDA initialization.

    The first construction of a ``_CudaBase`` subclass runs ``_lazy_init()``.
    It then deletes this hook from ``_CudaBase``, so later constructions go
    straight to the next ``__new__`` in the MRO. ``_after_fork`` assigns the
    hook back onto ``_CudaBase``.
    """
    _lazy_init()
    # We need this method only for lazy init, so we can remove it
    # (only after _lazy_init() returned, so a failed init keeps the hook)
    del _CudaBase.__new__
    return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
class _CudaBase(object):
    """Mixin listed first in the bases of every CUDA storage/tensor class.

    Marks instances as dense CUDA objects, defers CUDA initialization until
    the first instance is created, and runs ``type`` conversions with the
    object's own device selected.
    """
    is_cuda = True
    is_sparse = False

    # Lazy-initialization hook; removed from the class after first use.
    __new__ = _lazy_new

    def type(self, *args, **kwargs):
        owning_device = self.get_device()
        with device(owning_device):
            return super(_CudaBase, self).type(*args, **kwargs)
class DoubleStorage(_CudaBase, torch._C.CudaDoubleStorageBase, _StorageBase):
    """CUDA storage class for the ``Double`` element type."""
class FloatStorage(_CudaBase, torch._C.CudaFloatStorageBase, _StorageBase):
    """CUDA storage class for the ``Float`` element type."""
class LongStorage(_CudaBase, torch._C.CudaLongStorageBase, _StorageBase):
    """CUDA storage class for the ``Long`` element type."""
class IntStorage(_CudaBase, torch._C.CudaIntStorageBase, _StorageBase):
    """CUDA storage class for the ``Int`` element type."""
class ShortStorage(_CudaBase, torch._C.CudaShortStorageBase, _StorageBase):
    """CUDA storage class for the ``Short`` element type."""
class CharStorage(_CudaBase, torch._C.CudaCharStorageBase, _StorageBase):
    """CUDA storage class for the ``Char`` element type."""
class ByteStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
    """CUDA storage class for the ``Byte`` element type."""
class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase):
    """CUDA storage class for the ``Half`` element type."""
class DoubleTensor(_CudaBase, torch._C.CudaDoubleTensorBase, _TensorBase):
    """CUDA tensor of ``Double`` elements, backed by :class:`DoubleStorage`."""

    @classmethod
    def storage_type(cls):
        return DoubleStorage

    def is_signed(self):
        return True
class FloatTensor(_CudaBase, torch._C.CudaFloatTensorBase, _TensorBase):
    """CUDA tensor of ``Float`` elements, backed by :class:`FloatStorage`."""

    @classmethod
    def storage_type(cls):
        return FloatStorage

    def is_signed(self):
        return True
class LongTensor(_CudaBase, torch._C.CudaLongTensorBase, _TensorBase):
    """CUDA tensor of ``Long`` elements, backed by :class:`LongStorage`."""

    @classmethod
    def storage_type(cls):
        return LongStorage

    def is_signed(self):
        return True
class IntTensor(_CudaBase, torch._C.CudaIntTensorBase, _TensorBase):
    """CUDA tensor of ``Int`` elements, backed by :class:`IntStorage`."""

    @classmethod
    def storage_type(cls):
        return IntStorage

    def is_signed(self):
        return True
class ShortTensor(_CudaBase, torch._C.CudaShortTensorBase, _TensorBase):
    """CUDA tensor of ``Short`` elements, backed by :class:`ShortStorage`."""

    @classmethod
    def storage_type(cls):
        return ShortStorage

    def is_signed(self):
        return True
class CharTensor(_CudaBase, torch._C.CudaCharTensorBase, _TensorBase):
    """CUDA tensor of ``Char`` elements, backed by :class:`CharStorage`."""

    @classmethod
    def storage_type(cls):
        return CharStorage

    def is_signed(self):
        # TODO: reports unsigned; whether that is intended for Char is
        # unresolved — confirm before relying on it.
        return False
class ByteTensor(_CudaBase, torch._C.CudaByteTensorBase, _TensorBase):
    """CUDA tensor of ``Byte`` elements, backed by :class:`ByteStorage`."""

    @classmethod
    def storage_type(cls):
        return ByteStorage

    def is_signed(self):
        return False
class HalfTensor(_CudaBase, torch._C.CudaHalfTensorBase, _TensorBase):
    """CUDA tensor of ``Half`` elements, backed by :class:`HalfStorage`."""

    def is_signed(self):
        return True

    @classmethod
    def storage_type(cls):
        # A classmethod always receives the class as its first argument;
        # without the ``cls`` parameter every call raised TypeError.
        return HalfStorage
# Register every CUDA storage and tensor class in torch's class registries,
# in the same order as before.
for _cls in (DoubleStorage, FloatStorage, LongStorage, IntStorage,
             ShortStorage, CharStorage, ByteStorage, HalfStorage):
    torch._storage_classes.add(_cls)
for _cls in (DoubleTensor, FloatTensor, LongTensor, IntTensor,
             ShortTensor, CharTensor, ByteTensor, HalfTensor):
    torch._tensor_classes.add(_cls)
# Do not leave the loop variable behind as a module attribute.
del _cls
from . import sparse
from . import nvtx
from .streams import Stream, Event