We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d3504e1 commit e78d230Copy full SHA for e78d230
cuda_malloc.py
@@ -74,7 +74,8 @@ def cuda_malloc_supported():
74
module = importlib.util.module_from_spec(spec)
75
spec.loader.exec_module(module)
76
version = module.__version__
77
- if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
+
78
+ if int(version[0]) >= 2 and "+cu" in version: #enable by default for torch version 2.0 and up only on cuda torch
79
args.cuda_malloc = cuda_malloc_supported()
80
except:
81
pass
0 commit comments