
Commit 0aa185f

malfet authored and pytorchmergebot committed
[BE] Make torch.cuda.has_magma a build time check (pytorch#116299)

Perhaps originally one needed to query the GPU for its capabilities, but right now it is a simple check of a build-time flag: https://github.com/pytorch/pytorch/blob/52f0457d7dc8e1ee6e25b6c97d11d3070d11263b/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L165-L171

Alternatively, to avoid the `at::hasMAGMA()` call, one can implement it as follows:

```cpp
const auto use_magma = caffe2::GetBuildOptions().at("USE_MAGMA");
return PyBool_FromLong(use_magma == "1");
```

This makes the check very similar to `_has_mkldnn`: https://github.com/pytorch/pytorch/blob/0978482afa69118da1a986a4edec3acea01d2c6d/torch/csrc/Module.cpp#L1793-L1794

Test plan: run `lldb -- python3 -c "import torch;print(torch.cuda.has_magma)"` and make sure it returns `True` and that `cuInit` is not called.

Pull Request resolved: pytorch#116299
Approved by: https://github.com/seemethere, https://github.com/albanD
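The alternative mentioned in the commit message — deriving the boolean from a build-options table rather than a runtime query — can be sketched in Python. This is a hypothetical stand-in: `build_options` and `has_feature` are illustrative names, not PyTorch APIs, and the dict plays the role of `caffe2::GetBuildOptions()`.

```python
# Hypothetical stand-in for caffe2::GetBuildOptions(): a mapping of
# build-time flags to their string values, fixed at compile time.
build_options = {"USE_MAGMA": "1", "USE_MKLDNN": "0"}


def has_feature(flag: str) -> bool:
    # Mirrors the `use_magma == "1"` comparison from the C++ snippet:
    # a pure string comparison, with no GPU or driver initialization.
    return build_options.get(flag, "0") == "1"


print(has_feature("USE_MAGMA"))   # True
print(has_feature("USE_MKLDNN"))  # False
```

The point of the pattern is that answering the question never touches the CUDA driver, which is why `cuInit` should not appear in the test plan's lldb trace.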
1 parent 0edc348 commit 0aa185f

File tree

4 files changed: +5 additions, -4 deletions

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 0 deletions

```diff
@@ -1208,6 +1208,7 @@ has_mkl: _bool
 _has_mps: _bool
 has_lapack: _bool
 _has_cuda: _bool
+_has_magma: _bool
 _has_mkldnn: _bool
 _has_cudnn: _bool
 has_spectral: _bool
```

torch/csrc/Module.cpp

Lines changed: 2 additions & 0 deletions

```diff
@@ -1789,6 +1789,8 @@ Call this whenever a new thread is created in order to propagate values from
 #endif

   ASSERT_TRUE(set_module_attr("_has_cuda", has_cuda));
+  ASSERT_TRUE(
+      set_module_attr("_has_magma", at::hasMAGMA() ? Py_True : Py_False));
   ASSERT_TRUE(set_module_attr("_has_mps", has_mps));
   ASSERT_TRUE(
       set_module_attr("_has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));
```

torch/csrc/cuda/Module.cpp

Lines changed: 0 additions & 2 deletions

```diff
@@ -1285,8 +1285,6 @@ static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) {
   }
   };

-  set_module_attr("has_magma", at::hasMAGMA() ? Py_True : Py_False);
-
   auto num_gpus = c10::cuda::device_count();
   auto default_cuda_generators = PyTuple_New(static_cast<Py_ssize_t>(num_gpus));
   for (const auto i : c10::irange(num_gpus)) {
```

torch/cuda/__init__.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -111,8 +111,8 @@ def _maybe_exchange_device(device: int) -> int:


 has_half: bool = True
-# Global variables dynamically populated by native code
-has_magma: bool = False
+has_magma: bool = torch._C._has_magma
+
 default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]
```
