Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Build-time helpers for FlashMLA."""
99 changes: 99 additions & 0 deletions build_tools/maca_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import os
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Mapping


@dataclass(frozen=True)
class MacaBuildEnv:
maca_path: Path | None
cuda_path: Path | None
maca_clang_path: Path | None
maca_lib_path: Path | None
cucc_path: Path | None


def _path_from_env(env: Mapping[str, str], name: str) -> Path | None:
value = env.get(name)
cleaned = value.strip() if value else None
return Path(cleaned).expanduser() if cleaned else None


def _candidate_file(path: Path | None, relative: str) -> Path | None:
return path / relative if path is not None else None


def _find_executable(
name: str,
candidates: list[Path | None],
which: Callable[[str], str | None] = shutil.which,
) -> Path | None:
for candidate in candidates:
if candidate is not None and candidate.is_file():
return candidate
resolved = which(name)
return Path(resolved) if resolved else None


def resolve_maca_build_env(
env: Mapping[str, str] | None = None,
which: Callable[[str], str | None] = shutil.which,
) -> MacaBuildEnv:
env = os.environ if env is None else env
maca_path = _path_from_env(env, "MACA_PATH")
cuda_path = (
_path_from_env(env, "CUDA_HOME")
or _path_from_env(env, "CUDA_PATH")
or (maca_path / "tools" / "cu-bridge" if maca_path else None)
)
maca_clang_path = _path_from_env(env, "MACA_CLANG_PATH") or (
maca_path / "mxgpu_llvm" / "bin" if maca_path else None
)
maca_lib_path = maca_path / "lib" if maca_path else None
cucc_path = _find_executable(
"cucc",
[
_candidate_file(cuda_path, "bin/cucc"),
_candidate_file(maca_clang_path, "cucc"),
],
which=which,
)
return MacaBuildEnv(
maca_path=maca_path,
cuda_path=cuda_path,
maca_clang_path=maca_clang_path,
maca_lib_path=maca_lib_path,
cucc_path=cucc_path,
)


def validate_maca_build_env(build_env: MacaBuildEnv) -> list[str]:
errors: list[str] = []
required_dirs = {
"MACA_PATH": build_env.maca_path,
"CUDA_HOME/CUDA_PATH": build_env.cuda_path,
"MACA_CLANG_PATH": build_env.maca_clang_path,
"MACA library directory": build_env.maca_lib_path,
}
for label, path in required_dirs.items():
if path is None:
errors.append(f"{label} is not configured")
elif not path.is_dir():
errors.append(f"{label} does not exist or is not a directory: {path}")

if build_env.cucc_path is None:
errors.append("cucc compiler was not found in CUDA_HOME/bin, MACA_CLANG_PATH, or PATH")
elif not build_env.cucc_path.is_file():
errors.append(f"cucc compiler path is not a file: {build_env.cucc_path}")

return errors


def format_maca_build_env_errors(errors: list[str]) -> str:
hint = (
"Set MACA_PATH to the MACA toolkit root, CUDA_HOME or CUDA_PATH to "
"$MACA_PATH/tools/cu-bridge, and MACA_CLANG_PATH to "
"$MACA_PATH/mxgpu_llvm/bin before building FlashMLA from source."
)
return "Invalid MACA build environment:\n- " + "\n- ".join(errors) + "\n" + hint
24 changes: 18 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
CUDA_HOME,
)

from build_tools.maca_env import (
format_maca_build_env_errors,
resolve_maca_build_env,
validate_maca_build_env,
)
Comment on lines +29 to +33

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Since build_tools has been turned into a Python package by adding build_tools/__init__.py, setuptools.find_packages() will automatically detect and package it. This will result in build_tools being installed as a top-level package in the user's Python environment, polluting the global namespace.

To prevent this, please add "build_tools" to the exclude list of find_packages in setup.py (around line 340).



with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
Expand Down Expand Up @@ -152,25 +158,30 @@ def find_unmatched_cu_files(cu_dir, o_dir):
generator_flag = ["-DOLD_GENERATOR_PATH"]

check_if_cuda_home_none("flash_mla")
maca_build_env = resolve_maca_build_env()
maca_build_env_errors = validate_maca_build_env(maca_build_env)
if maca_build_env_errors:
raise RuntimeError(format_maca_build_env_errors(maca_build_env_errors))
cuda_home = str(maca_build_env.cuda_path)
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
if CUDA_HOME is not None:
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if cuda_home is not None:
_, bare_metal_version = get_cuda_bare_metal_version(cuda_home)
if bare_metal_version < Version("11.6"):
raise RuntimeError(
"FlashAttention is only supported on CUDA 11.6 and above. "
"Note: make sure nvcc has a supported version by running nvcc -V."
)
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if CUDA_HOME is not None:
if cuda_home is not None:
if bare_metal_version >= Version("11.8"):
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")

lib_dir = Path(CUDA_HOME).parent.parent / "lib"
libraries=["mcblas"]
extra_objects = ['{}/lib{}.so'.format(lib_dir, l) for l in libraries]
lib_dir = maca_build_env.maca_lib_path
libraries = ["mcblas"]
extra_objects = [str(lib_dir / f"lib{l}.so") for l in libraries]
# extra_objects.extend([f for f in obj_lists if f.endswith('.o')])

# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
Expand Down Expand Up @@ -335,6 +346,7 @@ def run(self):
"dist",
"docs",
"benchmarks",
"build_tools",
"flash_mla.egg-info",
)
),
Expand Down
87 changes: 87 additions & 0 deletions tests/test_maca_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import sys
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from build_tools.maca_env import (
format_maca_build_env_errors,
resolve_maca_build_env,
validate_maca_build_env,
)


class MacaEnvTest(unittest.TestCase):
def test_resolve_maca_build_env_defaults_to_cu_bridge(self):
with TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
maca_path = tmp_path / "maca"
cuda_path = maca_path / "tools" / "cu-bridge"
clang_path = maca_path / "mxgpu_llvm" / "bin"
for path in (cuda_path / "bin", clang_path, maca_path / "lib"):
path.mkdir(parents=True)
cucc = cuda_path / "bin" / "cucc"
cucc.write_text("#!/bin/sh\n", encoding="utf-8")

build_env = resolve_maca_build_env({"MACA_PATH": str(maca_path)}, which=lambda _: None)

self.assertEqual(build_env.maca_path, maca_path)
self.assertEqual(build_env.cuda_path, cuda_path)
self.assertEqual(build_env.maca_clang_path, clang_path)
self.assertEqual(build_env.maca_lib_path, maca_path / "lib")
self.assertEqual(build_env.cucc_path, cucc)
self.assertEqual(validate_maca_build_env(build_env), [])

def test_validate_maca_build_env_reports_missing_paths(self):
build_env = resolve_maca_build_env({}, which=lambda _: None)

errors = validate_maca_build_env(build_env)

self.assertIn("MACA_PATH is not configured", errors)
self.assertTrue(any("cucc compiler was not found" in error for error in errors))
self.assertIn("Invalid MACA build environment", format_maca_build_env_errors(errors))

def test_resolve_maca_build_env_uses_path_cucc(self):
with TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
maca_path = tmp_path / "maca"
cuda_path = tmp_path / "cu-bridge"
clang_path = tmp_path / "clang"
path_cucc = tmp_path / "bin" / "cucc"
for path in (maca_path / "lib", cuda_path, clang_path, path_cucc.parent):
path.mkdir(parents=True)
path_cucc.write_text("#!/bin/sh\n", encoding="utf-8")

build_env = resolve_maca_build_env(
{
"MACA_PATH": str(maca_path),
"CUDA_PATH": str(cuda_path),
"MACA_CLANG_PATH": str(clang_path),
},
which=lambda _: str(path_cucc),
)

self.assertEqual(build_env.cucc_path, path_cucc)
self.assertEqual(validate_maca_build_env(build_env), [])

def test_resolve_maca_build_env_strips_whitespace(self):
with TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
maca_path = tmp_path / "maca"
for path in (maca_path / "tools" / "cu-bridge" / "bin", maca_path / "mxgpu_llvm" / "bin", maca_path / "lib"):
path.mkdir(parents=True)
cucc = maca_path / "tools" / "cu-bridge" / "bin" / "cucc"
cucc.write_text("#!/bin/sh\n", encoding="utf-8")

build_env = resolve_maca_build_env(
{"MACA_PATH": f" {maca_path} "},
which=lambda _: None,
)

self.assertEqual(build_env.maca_path, maca_path)
self.assertEqual(build_env.cucc_path, cucc)


if __name__ == "__main__":
unittest.main()