diff --git a/setup.py b/setup.py index a0e59cf86..06555d4b0 100644 --- a/setup.py +++ b/setup.py @@ -171,7 +171,8 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append(CppExtension("apex_C", ["csrc/flatten_unflatten.cpp"])) -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) +if CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if has_flag("--distributed_adam", "APEX_DISTRIBUTED_ADAM"): if "--distributed_adam" in sys.argv: