"""Build-time helpers that locate and validate the MACA toolchain for FlashMLA."""

import os
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterable, List, Mapping, Optional

# The annotations below are evaluated when this module is imported, and
# setup.py imports it at top level.  ``typing.Optional`` / ``typing.List`` are
# used instead of ``X | None`` / ``list[str]`` so the import does not fail on
# interpreters older than 3.10 / 3.9.
# NOTE(review): confirm the minimum Python version the build must support.


@dataclass(frozen=True)
class MacaBuildEnv:
    """Toolchain locations produced by :func:`resolve_maca_build_env`.

    Every field is ``None`` when the location could not be determined.
    Existence on disk is not checked here; see
    :func:`validate_maca_build_env`.
    """

    # Value of ``MACA_PATH``.
    maca_path: Optional[Path]
    # ``CUDA_HOME``, else ``CUDA_PATH``, else ``<maca_path>/tools/cu-bridge``.
    cuda_path: Optional[Path]
    # ``MACA_CLANG_PATH``, else ``<maca_path>/mxgpu_llvm/bin``.
    maca_clang_path: Optional[Path]
    # ``<maca_path>/lib``.
    maca_lib_path: Optional[Path]
    # Location of the ``cucc`` compiler.
    cucc_path: Optional[Path]


def _path_from_env(env: Mapping[str, str], name: str) -> Optional[Path]:
    """Return ``env[name]`` as a ``~``-expanded path.

    Surrounding whitespace is stripped; an unset, empty, or whitespace-only
    value yields ``None``.
    """
    value = env.get(name)
    cleaned = value.strip() if value else None
    return Path(cleaned).expanduser() if cleaned else None


def _candidate_file(path: Optional[Path], relative: str) -> Optional[Path]:
    """Return ``path / relative``, or ``None`` when ``path`` is ``None``."""
    return path / relative if path is not None else None


def _find_executable(
    name: str,
    candidates: Iterable[Optional[Path]],
    which: Callable[[str], Optional[str]] = shutil.which,
) -> Optional[Path]:
    """Return the first candidate that is an existing file, else ``which(name)``.

    Args:
        name: Program name handed to ``which`` when no candidate matches.
        candidates: Paths to probe in order; ``None`` entries are skipped.
        which: Lookup used as the fallback (defaults to ``shutil.which``).

    Returns:
        The matching path, or ``None`` when nothing was found.
    """
    for candidate in candidates:
        if candidate is not None and candidate.is_file():
            return candidate
    resolved = which(name)
    return Path(resolved) if resolved else None


def resolve_maca_build_env(
    env: Optional[Mapping[str, str]] = None,
    which: Callable[[str], Optional[str]] = shutil.which,
) -> MacaBuildEnv:
    """Derive the MACA build locations from environment variables.

    Args:
        env: Mapping to read variables from; ``os.environ`` when ``None``.
        which: Lookup used to find ``cucc`` when it is not at one of the
            derived locations (defaults to ``shutil.which``).

    Returns:
        A :class:`MacaBuildEnv`.  Nothing is validated here, so fields may be
        ``None`` or point at paths that do not exist.
    """
    env = os.environ if env is None else env
    maca_path = _path_from_env(env, "MACA_PATH")
    # Explicit CUDA variables win; otherwise fall back to the toolkit layout.
    cuda_path = (
        _path_from_env(env, "CUDA_HOME")
        or _path_from_env(env, "CUDA_PATH")
        or (maca_path / "tools" / "cu-bridge" if maca_path else None)
    )
    maca_clang_path = _path_from_env(env, "MACA_CLANG_PATH") or (
        maca_path / "mxgpu_llvm" / "bin" if maca_path else None
    )
    maca_lib_path = maca_path / "lib" if maca_path else None
    cucc_path = _find_executable(
        "cucc",
        [
            _candidate_file(cuda_path, "bin/cucc"),
            _candidate_file(maca_clang_path, "cucc"),
        ],
        which=which,
    )
    return MacaBuildEnv(
        maca_path=maca_path,
        cuda_path=cuda_path,
        maca_clang_path=maca_clang_path,
        maca_lib_path=maca_lib_path,
        cucc_path=cucc_path,
    )


def validate_maca_build_env(build_env: MacaBuildEnv) -> List[str]:
    """Return one message per problem found in ``build_env``.

    Each of the four directory fields must be set and be an existing
    directory, and ``cucc_path`` must be set and be an existing file.

    Returns:
        A list of error messages; empty when every check passes.
    """
    errors: List[str] = []
    required_dirs = {
        "MACA_PATH": build_env.maca_path,
        "CUDA_HOME/CUDA_PATH": build_env.cuda_path,
        "MACA_CLANG_PATH": build_env.maca_clang_path,
        "MACA library directory": build_env.maca_lib_path,
    }
    for label, path in required_dirs.items():
        if path is None:
            errors.append(f"{label} is not configured")
        elif not path.is_dir():
            errors.append(f"{label} does not exist or is not a directory: {path}")

    if build_env.cucc_path is None:
        errors.append("cucc compiler was not found in CUDA_HOME/bin, MACA_CLANG_PATH, or PATH")
    elif not build_env.cucc_path.is_file():
        errors.append(f"cucc compiler path is not a file: {build_env.cucc_path}")

    return errors


def format_maca_build_env_errors(errors: Iterable[str]) -> str:
    """Render ``errors`` as a bulleted message followed by a setup hint.

    Args:
        errors: Messages such as those from :func:`validate_maca_build_env`.

    Returns:
        A header line, one ``- `` bullet per error, then the hint.  With no
        errors the bullets are omitted rather than emitting an empty one.
    """
    hint = (
        "Set MACA_PATH to the MACA toolkit root, CUDA_HOME or CUDA_PATH to "
        "$MACA_PATH/tools/cu-bridge, and MACA_CLANG_PATH to "
        "$MACA_PATH/mxgpu_llvm/bin before building FlashMLA from source."
    )
    lines = ["Invalid MACA build environment:"]
    lines.extend(f"- {error}" for error in errors)
    lines.append(hint)
    return "\n".join(lines)
"""Unit tests for ``build_tools.maca_env``."""

import sys
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory

# Put the directory that contains ``tests/`` on the import path so that
# ``build_tools`` resolves when this file is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from build_tools.maca_env import (
    format_maca_build_env_errors,
    resolve_maca_build_env,
    validate_maca_build_env,
)


def _nothing_on_path(_name):
    """Stand-in for ``shutil.which`` that never finds a program."""
    return None


class MacaEnvTest(unittest.TestCase):
    """Exercise resolution, validation and formatting of the MACA build env."""

    def test_resolve_maca_build_env_defaults_to_cu_bridge(self):
        """Only MACA_PATH set: every other location derives from it."""
        with TemporaryDirectory() as scratch:
            toolkit = Path(scratch) / "maca"
            bridge = toolkit / "tools" / "cu-bridge"
            llvm_bin = toolkit / "mxgpu_llvm" / "bin"
            lib_dir = toolkit / "lib"
            for directory in (bridge / "bin", llvm_bin, lib_dir):
                directory.mkdir(parents=True)
            compiler = bridge / "bin" / "cucc"
            compiler.write_text("#!/bin/sh\n", encoding="utf-8")

            resolved = resolve_maca_build_env(
                {"MACA_PATH": str(toolkit)},
                which=_nothing_on_path,
            )

            self.assertEqual(resolved.maca_path, toolkit)
            self.assertEqual(resolved.cuda_path, bridge)
            self.assertEqual(resolved.maca_clang_path, llvm_bin)
            self.assertEqual(resolved.maca_lib_path, lib_dir)
            self.assertEqual(resolved.cucc_path, compiler)
            self.assertEqual(validate_maca_build_env(resolved), [])

    def test_validate_maca_build_env_reports_missing_paths(self):
        """An empty environment yields errors and a formatted message."""
        resolved = resolve_maca_build_env({}, which=_nothing_on_path)

        problems = validate_maca_build_env(resolved)

        self.assertIn("MACA_PATH is not configured", problems)
        found_cucc_error = any(
            "cucc compiler was not found" in problem for problem in problems
        )
        self.assertTrue(found_cucc_error)
        rendered = format_maca_build_env_errors(problems)
        self.assertIn("Invalid MACA build environment", rendered)

    def test_resolve_maca_build_env_uses_path_cucc(self):
        """With no cucc at the derived locations, the ``which`` result is used."""
        with TemporaryDirectory() as scratch:
            base = Path(scratch)
            toolkit = base / "maca"
            bridge = base / "cu-bridge"
            clang_dir = base / "clang"
            compiler = base / "bin" / "cucc"
            for directory in (toolkit / "lib", bridge, clang_dir, compiler.parent):
                directory.mkdir(parents=True)
            compiler.write_text("#!/bin/sh\n", encoding="utf-8")

            environment = {
                "MACA_PATH": str(toolkit),
                "CUDA_PATH": str(bridge),
                "MACA_CLANG_PATH": str(clang_dir),
            }
            resolved = resolve_maca_build_env(
                environment,
                which=lambda _name: str(compiler),
            )

            self.assertEqual(resolved.cucc_path, compiler)
            self.assertEqual(validate_maca_build_env(resolved), [])

    def test_resolve_maca_build_env_strips_whitespace(self):
        """Whitespace around MACA_PATH is ignored."""
        with TemporaryDirectory() as scratch:
            toolkit = Path(scratch) / "maca"
            bridge_bin = toolkit / "tools" / "cu-bridge" / "bin"
            for directory in (bridge_bin, toolkit / "mxgpu_llvm" / "bin", toolkit / "lib"):
                directory.mkdir(parents=True)
            compiler = bridge_bin / "cucc"
            compiler.write_text("#!/bin/sh\n", encoding="utf-8")

            resolved = resolve_maca_build_env(
                {"MACA_PATH": f" {toolkit} "},
                which=_nothing_on_path,
            )

            self.assertEqual(resolved.maca_path, toolkit)
            self.assertEqual(resolved.cucc_path, compiler)


if __name__ == "__main__":
    unittest.main()